(wip) use select instead of a thread by server. Currently read and write seems to work properly, we lost some code (like more, next, ...) that never should be in server part

This commit is contained in:
nemunaire 2014-08-30 19:15:14 +02:00
parent 28c1ad088b
commit 81593a493b
8 changed files with 205 additions and 419 deletions

33
bot.py
View File

@ -21,6 +21,7 @@ from datetime import timedelta
import logging
from queue import Queue
import re
from select import select
import threading
import time
import uuid
@ -38,7 +39,7 @@ import response
logger = logging.getLogger("nemubot.bot")
class Bot:
class Bot(threading.Thread):
"""Class containing the bot context and ensuring key goals"""
@ -51,6 +52,8 @@ class Bot:
data_path -- Path to directory where store bot context data
"""
threading.Thread.__init__(self)
logger.info("Initiate nemubot v%s", __version__)
# External IP for accessing this bot
@ -90,6 +93,22 @@ class Bot:
self)
def run(self):
from server import _rlist, _wlist, _xlist
self.stop = False
while not self.stop:
rl, wl, xl = select(_rlist, _wlist, _xlist, 0.1)
for x in xl:
x.exception()
for w in wl:
w.write_select()
for r in rl:
for i in r.read():
self.receive_message(r, i)
def init_ctcp_capabilities(self):
"""Reset existing CTCP capabilities to default one"""
@ -277,17 +296,16 @@ class Bot:
c.start()
def add_server(self, node, nick, owner, realname, ssl=False):
def add_server(self, node, nick, owner, realname):
"""Add a new server to the context"""
srv = IRCServer(node, nick, owner, realname, ssl)
srv = IRCServer(node, nick, owner, realname)
srv.add_hook = lambda h: self.hooks.add_hook("irc_hook", h, self)
srv.add_networkbot = self.add_networkbot
srv.send_bot = lambda d: self.send_networkbot(srv, d)
srv.register_hooks()
#srv.register_hooks()
if srv.id not in self.servers:
self.servers[srv.id] = srv
if srv.autoconnect:
srv.launch(self.receive_message)
srv.open()
return True
else:
return False
@ -378,6 +396,9 @@ class Bot:
for srv in k:
self.servers[srv].disconnect()
self.stop = True
# Hooks cache
def create_cache(self, name):

View File

@ -46,9 +46,8 @@ class MessageConsumer:
def treat_in(self, context, msg):
"""Treat the input message"""
if msg.cmd == "PING":
self.srv.send_pong(msg.params[0])
self.srv.write("%s :%s" % ("PONG", msg.params[0]))
elif hasattr(msg, "receivers"):
msg.receivers = [ receiver for receiver in msg.receivers if self.srv.accepted_channel(receiver) ]
if msg.receivers:
# All messages
context.treat_pre(msg, self.srv)
@ -62,21 +61,29 @@ class MessageConsumer:
if r is not None: self.treat_out(context, r)
elif isinstance(res, response.Response):
# Define the destination server
if (res.server is not None and
isinstance(res.server, str) and res.server in context.servers):
res.server = context.servers[res.server]
if (res.server is not None and
not isinstance(res.server, server.Server)):
# Define the destination server
to_server = None
if res.server is None:
to_server = self.srv
res.server = self.srv.id
elif isinstance(res.server, str) and res.server in context.servers:
to_server = context.servers[res.server]
if to_server is None:
logger.error("The server defined in this response doesn't "
"exist: %s", res.server)
res.server = None
if res.server is None:
res.server = self.srv
return False
# Sent the message only if treat_post authorize it
if context.treat_post(res):
res.server.send_response(res, self.data)
if type(res.channel) != list:
res.channel = [ res.channel ]
for channel in res.channel:
if channel != to_server.nick:
to_server.write("%s %s :%s" % ("PRIVMSG", channel, res.get_message()))
else:
channel = res.sender
to_server.write("%s %s :%s" % ("NOTICE" if res.is_ctcp else "PRIVMSG", channel, res.get_message()))
elif res is not None:
logger.error("Unrecognized response type: %s", res)
@ -98,7 +105,7 @@ class MessageConsumer:
self.treat_out(context, res)
# Inform that the message has been treated
self.srv.msg_treated(self.data)
#self.srv.msg_treated(self.data)

View File

@ -68,6 +68,7 @@ if __name__ == "__main__":
print ("Nemubot v%s ready, my PID is %i!" % (bot.__version__,
os.getpid()))
context.start()
while prmpt.run(context):
try:
# Reload context

View File

@ -69,8 +69,7 @@ def load_file(filename, context):
nick = server["nick"] if server.hasAttribute("nick") else config["nick"]
owner = server["owner"] if server.hasAttribute("owner") else config["owner"]
realname = server["realname"] if server.hasAttribute("realname") else config["realname"]
if context.add_server(server, nick, owner, realname,
server.hasAttribute("ssl")):
if context.add_server(server, nick, owner, realname):
print("Server `%s:%s' successfully added." %
(server["server"], server["port"]))
else:

View File

@ -42,6 +42,13 @@ class Response:
self.set_sender(sender)
self.count = count
@property
def receivers(self):
if type(self.channel) is list:
return self.channel
else:
return [ self.channel ]
@property
def content(self):
#FIXME: error when messages in self.messages are list!

View File

@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# Nemubot is a modulable IRC bot, built around XML configuration files.
# Copyright (C) 2012 Mercier Pierre-Olivier
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2014 nemunaire
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -16,272 +16,32 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import errno
import os
import ssl
import socket
import threading
import traceback
from channel import Channel
from server.DCC import DCC
from hooks import Hook
import message
import server
import xmlparser
from server.socket import SocketServer
class IRCServer(server.Server):
"""Class to interact with an IRC server"""
class IRCServer(SocketServer):
def __init__(self, node, nick, owner, realname, ssl=False):
"""Initialize an IRC server
Arguments:
node -- server node from XML configuration
nick -- nick used by the bot on this server
owner -- nick used by the bot owner on this server
realname -- string used as realname on this server
ssl -- require SSL?
"""
self.node = node
server.Server.__init__(self)
def __init__(self, node, nick, owner, realname):
SocketServer.__init__(self,
node["host"],
node["port"],
node["password"],
node.hasAttribute("ssl") and node["ssl"].lower() == "true")
self.nick = nick
self.owner = owner
self.realname = realname
self.ssl = ssl
self.id = "TODO"
# Listen private messages?
self.listen_nick = True
self.dcc_clients = dict()
self.channels = dict()
for chn in self.node.getNodes("channel"):
chan = Channel(chn["name"], chn["password"])
self.channels[chan.name] = chan
@property
def host(self):
"""Return the server hostname"""
if self.node is not None and self.node.hasAttribute("server"):
return self.node["server"]
else:
return "localhost"
@property
def port(self):
"""Return the connection port used on this server"""
if self.node is not None and self.node.hasAttribute("port"):
return self.node.getInt("port")
else:
return "6667"
@property
def password(self):
"""Return the password used to connect to this server"""
if self.node is not None and self.node.hasAttribute("password"):
return self.node["password"]
else:
return None
@property
def allow_all(self):
"""If True, treat message from all channels, not only listed one"""
return (self.node is not None and self.node.hasAttribute("allowall")
and self.node["allowall"] == "true")
@property
def autoconnect(self):
"""Autoconnect the server when added"""
if self.node is not None and self.node.hasAttribute("autoconnect"):
value = self.node["autoconnect"].lower()
return value != "no" and value != "off" and value != "false"
else:
return False
@property
def id(self):
"""Gives the server identifiant"""
return self.host + ":" + str(self.port)
def register_hooks(self):
self.add_hook(Hook(self.evt_channel, "JOIN"))
self.add_hook(Hook(self.evt_channel, "PART"))
self.add_hook(Hook(self.evt_server, "NICK"))
self.add_hook(Hook(self.evt_server, "QUIT"))
self.add_hook(Hook(self.evt_channel, "332"))
self.add_hook(Hook(self.evt_channel, "353"))
def evt_server(self, msg, srv):
for chan in self.channels:
self.channels[chan].treat(msg.cmd, msg)
def evt_channel(self, msg, srv):
if msg.receivers is not None:
for receiver in msg.receivers:
if receiver in self.channels:
self.channels[receiver].treat(msg.cmd, msg)
def accepted_channel(self, chan, sender=None):
"""Return True if the channel (or the user) is authorized"""
return (self.allow_all or
(chan in self.channels and (sender is None or sender in self.channels[chan].people)) or
(self.listen_nick and chan == self.nick))
def join(self, chan, password=None, force=False):
"""Join a channel"""
if force or (chan is not None and
self.connected and chan not in self.channels):
self.channels[chan] = Channel(chan, password)
if password is not None:
self.s.send(("JOIN %s %s\r\n" % (chan, password)).encode())
else:
self.s.send(("JOIN %s\r\n" % chan).encode())
def _open(self):
if SocketServer._open(self):
if self.password is not None:
self.write("PASS :" + self.password)
self.write("NICK :" + self.nick)
self.write("USER %s %s bla :%s" % (self.nick, self.host, self.realname))
return True
else:
return False
return False
def leave(self, chan):
"""Leave a channel"""
if chan is not None and self.connected and chan in self.channels:
if isinstance(chan, list):
for c in chan:
self.leave(c)
else:
self.s.send(("PART %s\r\n" % self.channels[chan].name).encode())
del self.channels[chan]
return True
else:
return False
# Main loop
def run(self):
if not self.connected:
self.s = socket.socket() #Create the socket
if self.ssl:
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
self.s = ctx.wrap_socket(self.s)
try:
self.s.connect((self.host, self.port)) #Connect to server
except socket.error as e:
self.s = None
self.logger.critical("Unable to connect to %s:%d: %s",
self.host, self.port, os.strerror(e.errno))
return
self.stopping.clear()
if self.password != None:
self.s.send(b"PASS " + self.password.encode () + b"\r\n")
self.s.send(("NICK %s\r\n" % self.nick).encode ())
self.s.send(("USER %s %s bla :%s\r\n" % (self.nick, self.host,
self.realname)).encode())
raw = self.s.recv(1024)
if not raw:
self.logger.critical("Unable to connect to %s:%d", self.host, self.port)
return
self.connected = True
self.logger.info("Connection to %s:%d completed", self.host, self.port)
if len(self.channels) > 0:
for chn in self.channels.keys():
self.join(self.channels[chn].name,
self.channels[chn].password, force=True)
readbuffer = b'' #Here we store all the messages from server
try:
while not self.stop:
readbuffer = readbuffer + raw
temp = readbuffer.split(b'\n')
readbuffer = temp.pop()
for line in temp:
self.treat_msg(line)
raw = self.s.recv(1024) #recieve server messages
except socket.error:
pass
if self.connected:
self.s.close()
self.connected = False
if self.closing_event is not None:
self.closing_event()
self.logger.info("Server `%s' successfully stopped.", self.id)
self.stopping.set()
# Rearm Thread
threading.Thread.__init__(self)
# Overwritted methods
def disconnect(self):
"""Close the socket with the server and all DCC client connections"""
#Close all DCC connection
clts = [c for c in self.dcc_clients]
for clt in clts:
self.dcc_clients[clt].disconnect()
return server.Server.disconnect(self)
# Abstract methods
def send_pong(self, cnt):
"""Send a PONG command to the server with argument cnt"""
self.s.send(("PONG %s\r\n" % cnt).encode())
def msg_treated(self, origin):
"""Do nothing; here for implement abstract class"""
pass
def send_dcc(self, msg, to):
"""Send a message through DCC connection"""
if msg is not None and to is not None:
realname = to.split("!")[1]
if realname not in self.dcc_clients.keys():
d = DCC(self, to)
self.dcc_clients[realname] = d
self.dcc_clients[realname].send_dcc(msg)
def send_msg_final(self, channel, line, cmd="PRIVMSG", endl="\r\n"):
"""Send a message without checks or format"""
#TODO: add something for post message treatment here
if channel == self.nick:
self.logger.warn("Nemubot talks to himself: %s", line, stack_info=True)
if line is not None and channel is not None:
if self.s is None:
self.logger.warn("Attempt to send message on a non connected server: %s: %s", self.id, line, stack_info=True)
elif len(line) < 442:
self.s.send(("%s %s :%s%s" % (cmd, channel, line, endl)).encode ())
else:
self.logger.warn("Message truncated due to size (%d ; max : 442) : %s", len(line), line, stack_info=True)
self.s.send (("%s %s :%s%s" % (cmd, channel, line[0:442]+"<…>", endl)).encode ())
def send_msg_usr(self, user, msg):
"""Send a message to a user instead of a channel"""
if user is not None and user[0] != "#":
realname = user.split("!")[1]
if realname in self.dcc_clients or user in self.dcc_clients:
self.send_dcc(msg, user)
else:
for line in msg.split("\n"):
if line != "":
self.send_msg_final(user.split('!')[0], msg)
def send_msg(self, channel, msg, cmd="PRIVMSG", endl="\r\n"):
"""Send a message to a channel"""
if self.accepted_channel(channel):
server.Server.send_msg(self, channel, msg, cmd, endl)
def send_msg_verified(self, sender, channel, msg, cmd = "PRIVMSG", endl = "\r\n"):
"""Send a message to a channel, only if the source user is on this channel too"""
if self.accepted_channel(channel, sender):
self.send_msg_final(channel, msg, cmd, endl)
def send_global(self, msg, cmd="PRIVMSG", endl="\r\n"):
"""Send a message to all channels on this server"""
for channel in self.channels.keys():
self.send_msg(channel, msg, cmd, endl)
def _close(self):
self.write("QUIT")
SocketServer._close(self)

View File

@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# Nemubot is a modulable IRC bot, built around XML configuration files.
# Copyright (C) 2012 Mercier Pierre-Olivier
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2014 nemunaire
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -16,156 +16,70 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import io
import logging
import socket
import threading
import queue
class Server(threading.Thread):
def __init__(self, socket = None):
self.stop = False
self.stopping = threading.Event()
self.s = socket
self.connected = self.s is not None
self.closing_event = None
# Lists for select
_rlist = []
_wlist = []
_xlist = []
self.moremessages = dict()
# Extends from IOBase in order to be compatible with select function
class AbstractServer(io.IOBase):
self.logger = logging.getLogger("nemubot.server." + self.id)
"""An abstract server: handle communication with an IM server"""
threading.Thread.__init__(self)
def __init__(self, send_callback=None):
"""Initialize an abstract server
def isDCC(self, to=None):
return to is not None and to in self.dcc_clients
Keyword argument:
send_callback -- Callback when developper want to send a message
"""
@property
def ip(self):
"""Convert common IP representation to little-endian integer representation"""
sum = 0
if self.node.hasAttribute("ip"):
ip = self.node["ip"]
self.logger = logging.getLogger("nemubot.server.TODO")
self._sending_queue = queue.Queue()
if send_callback is not None:
self._send_callback = send_callback
else:
#TODO: find the external IP
ip = "0.0.0.0"
for b in ip.split("."):
sum = 256 * sum + int(b)
return sum
self._send_callback = self._write_select
def toIP(self, input):
"""Convert little-endian int to IPv4 adress"""
ip = ""
for i in range(0,4):
mod = input % 256
ip = "%d.%s" % (mod, ip)
input = (input - mod) / 256
return ip[:len(ip) - 1]
@property
def id(self):
"""Gives the server identifiant"""
raise NotImplemented()
def open(self):
"""Generic open function that register the server un _rlist in case of successful _open"""
if self._open():
_rlist.append(self)
def accepted_channel(self, msg, sender=None):
return True
def msg_treated(self, origin):
"""Action done on server when a message was treated"""
raise NotImplemented()
def close(self):
"""Generic close function that register the server un _{r,w,x}list in case of successful _close"""
if self._close():
if self in _rlist:
_rlist.remove(self)
if self in _wlist:
_wlist.remove(self)
if self in _xlist:
_xlist.remove(self)
def send_response(self, res, origin):
"""Analyse a Response and send it"""
if type(res.channel) != list:
res.channel = [ res.channel ]
for channel in res.channel:
if channel != self.nick:
self.send_msg(channel, res.get_message())
else:
channel = res.sender
self.send_msg_usr(channel, res.get_message(), "NOTICE" if res.is_ctcp else "PRIVMSG")
def write(self, message):
"""Send a message to the server using send_callback"""
self._send_callback(message)
if not res.alone:
if hasattr(self, "send_bot"):
self.send_bot("NOMORE %s" % res.channel)
self.moremessages[channel] = res
def write_select(self):
"""Internal function used by the select function"""
try:
while not self._sending_queue.empty():
self._write(self._sending_queue.get_nowait())
_wlist.remove(self)
def send_ctcp(self, to, msg, cmd="NOTICE", endl="\r\n"):
"""Send a message as CTCP response"""
if msg is not None and to is not None:
for line in msg.split("\n"):
if line != "":
self.send_msg_final(to.split("!")[0], "\x01" + line + "\x01", cmd, endl)
except queue.Empty:
pass
def send_dcc(self, msg, to):
"""Send a message through DCC connection"""
raise NotImplemented()
def send_msg_final(self, channel, msg, cmd="PRIVMSG", endl="\r\n"):
"""Send a message without checks or format"""
raise NotImplemented()
def send_msg_usr(self, user, msg):
"""Send a message to a user instead of a channel"""
raise NotImplemented()
def send_msg(self, channel, msg, cmd="PRIVMSG", endl="\r\n"):
"""Send a message to a channel"""
if msg is not None:
for line in msg.split("\n"):
if line != "":
self.send_msg_final(channel, line, cmd, endl)
def send_msg_verified(self, sender, channel, msg, cmd="PRIVMSG", endl="\r\n"):
"""A more secure way to send messages"""
raise NotImplemented()
def send_global(self, msg, cmd="PRIVMSG", endl="\r\n"):
"""Send a message to all channels on this server"""
raise NotImplemented()
def disconnect(self):
"""Close the socket with the server"""
if self.connected:
self.stop = True
try:
self.s.shutdown(socket.SHUT_RDWR)
except socket.error:
pass
self.stopping.wait()
return True
else:
return False
def kill(self):
"""Just stop the main loop, don't close the socket directly"""
if self.connected:
self.stop = True
self.connected = False
#Send a message in order to close the socket
try:
self.s.send(("Bye!\r\n").encode ())
except:
pass
self.stopping.wait()
return True
else:
return False
def launch(self, receive_action, verb=True):
"""Connect to the server if it is no yet connected"""
self._receive_action = receive_action
if not self.connected:
self.stop = False
self.logger.info("Entering main loop for server")
try:
self.start()
except RuntimeError:
pass
elif verb:
print (" Already connected.")
def treat_msg(self, line, private=False):
self._receive_action(self, line, private)
def run(self):
raise NotImplemented()
def _write_select(self, message):
"""Send a message to the server safely through select"""
self._sending_queue.put(self.format(message))
self.logger.debug("Message '%s' appended to Queue", message)
if self not in _wlist:
_wlist.append(self)

77
server/socket.py Normal file
View File

@ -0,0 +1,77 @@
# -*- coding: utf-8 -*-
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2014 nemunaire
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import ssl
import socket
from server import AbstractServer
class SocketServer(AbstractServer):
def __init__(self, host, port=6667, password=None, ssl=False):
AbstractServer.__init__(self)
self.host = host
self.port = int(port)
self.password = password
self.ssl = ssl
self.socket = None
self.readbuffer = b''
def fileno(self):
return self.socket.fileno() if self.socket else None
def _open(self):
# Create the socket
self.socket = socket.socket()
# Wrap the socket for SSL
if self.ssl:
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
self.socket = ctx.wrap_socket(self.socket)
try:
self.socket.connect((self.host, self.port)) #Connect to server
self.logger.info("Connected to %s:%d", self.host, self.port)
except socket.error as e:
self.socket = None
self.logger.critical("Unable to connect to %s:%d: %s",
self.host, self.port, os.strerror(e.errno))
return False
return True
def _close(self):
if self.socket is not None:
self.socket.shutdown(SHUT_RDWR)
self.socket.close()
self.socket = None
def _write(self, cnt):
self.socket.send(cnt)
def format(self, txt):
return txt.encode() + b'\r\n'
def read(self):
raw = self.socket.recv(1024)
temp = (self.readbuffer + raw).split(b'\r\n')
self.readbuffer = temp.pop()
for line in temp:
yield line