diff --git a/bot.py b/bot.py
index 2ac55b1..980f1bb 100644
--- a/bot.py
+++ b/bot.py
@@ -21,6 +21,7 @@ from datetime import timedelta
import logging
from queue import Queue
import re
+from select import select
import threading
import time
import uuid
@@ -38,7 +39,7 @@ import response
logger = logging.getLogger("nemubot.bot")
-class Bot:
+class Bot(threading.Thread):
"""Class containing the bot context and ensuring key goals"""
@@ -51,6 +52,8 @@ class Bot:
data_path -- Path to directory where store bot context data
"""
+ threading.Thread.__init__(self)
+
logger.info("Initiate nemubot v%s", __version__)
# External IP for accessing this bot
@@ -90,6 +93,22 @@ class Bot:
self)
+ def run(self):
+ from server import _rlist, _wlist, _xlist
+
+ self.stop = False
+ while not self.stop:
+ rl, wl, xl = select(_rlist, _wlist, _xlist, 0.1)
+
+ for x in xl:
+ x.exception()
+ for w in wl:
+ w.write_select()
+ for r in rl:
+ for i in r.read():
+ self.receive_message(r, i)
+
+
def init_ctcp_capabilities(self):
"""Reset existing CTCP capabilities to default one"""
@@ -277,17 +296,16 @@ class Bot:
c.start()
- def add_server(self, node, nick, owner, realname, ssl=False):
+ def add_server(self, node, nick, owner, realname):
"""Add a new server to the context"""
- srv = IRCServer(node, nick, owner, realname, ssl)
+ srv = IRCServer(node, nick, owner, realname)
srv.add_hook = lambda h: self.hooks.add_hook("irc_hook", h, self)
srv.add_networkbot = self.add_networkbot
srv.send_bot = lambda d: self.send_networkbot(srv, d)
- srv.register_hooks()
+ #srv.register_hooks()
if srv.id not in self.servers:
self.servers[srv.id] = srv
- if srv.autoconnect:
- srv.launch(self.receive_message)
+ srv.open()
return True
else:
return False
@@ -378,6 +396,9 @@ class Bot:
for srv in k:
self.servers[srv].disconnect()
+ self.stop = True
+
+
# Hooks cache
def create_cache(self, name):
diff --git a/consumer.py b/consumer.py
index 9ad8304..5d41891 100644
--- a/consumer.py
+++ b/consumer.py
@@ -46,9 +46,8 @@ class MessageConsumer:
def treat_in(self, context, msg):
"""Treat the input message"""
if msg.cmd == "PING":
- self.srv.send_pong(msg.params[0])
+ self.srv.write("%s :%s" % ("PONG", msg.params[0]))
elif hasattr(msg, "receivers"):
- msg.receivers = [ receiver for receiver in msg.receivers if self.srv.accepted_channel(receiver) ]
if msg.receivers:
# All messages
context.treat_pre(msg, self.srv)
@@ -62,21 +61,29 @@ class MessageConsumer:
if r is not None: self.treat_out(context, r)
elif isinstance(res, response.Response):
- # Define the destination server
- if (res.server is not None and
- isinstance(res.server, str) and res.server in context.servers):
- res.server = context.servers[res.server]
- if (res.server is not None and
- not isinstance(res.server, server.Server)):
+ # Define the destination server
+ to_server = None
+ if res.server is None:
+ to_server = self.srv
+ res.server = self.srv.id
+ elif isinstance(res.server, str) and res.server in context.servers:
+ to_server = context.servers[res.server]
+
+ if to_server is None:
logger.error("The server defined in this response doesn't "
"exist: %s", res.server)
- res.server = None
- if res.server is None:
- res.server = self.srv
+ return False
# Sent the message only if treat_post authorize it
if context.treat_post(res):
- res.server.send_response(res, self.data)
+ if type(res.channel) != list:
+ res.channel = [ res.channel ]
+ for channel in res.channel:
+ if channel != to_server.nick:
+ to_server.write("%s %s :%s" % ("PRIVMSG", channel, res.get_message()))
+ else:
+ channel = res.sender
+ to_server.write("%s %s :%s" % ("NOTICE" if res.is_ctcp else "PRIVMSG", channel, res.get_message()))
elif res is not None:
logger.error("Unrecognized response type: %s", res)
@@ -98,7 +105,7 @@ class MessageConsumer:
self.treat_out(context, res)
# Inform that the message has been treated
- self.srv.msg_treated(self.data)
+ #self.srv.msg_treated(self.data)
diff --git a/nemubot.py b/nemubot.py
index 9d17ef8..f04899c 100755
--- a/nemubot.py
+++ b/nemubot.py
@@ -68,6 +68,7 @@ if __name__ == "__main__":
print ("Nemubot v%s ready, my PID is %i!" % (bot.__version__,
os.getpid()))
+ context.start()
while prmpt.run(context):
try:
# Reload context
diff --git a/prompt/builtins.py b/prompt/builtins.py
index 7633e46..f23400d 100644
--- a/prompt/builtins.py
+++ b/prompt/builtins.py
@@ -69,8 +69,7 @@ def load_file(filename, context):
nick = server["nick"] if server.hasAttribute("nick") else config["nick"]
owner = server["owner"] if server.hasAttribute("owner") else config["owner"]
realname = server["realname"] if server.hasAttribute("realname") else config["realname"]
- if context.add_server(server, nick, owner, realname,
- server.hasAttribute("ssl")):
+ if context.add_server(server, nick, owner, realname):
print("Server `%s:%s' successfully added." %
(server["server"], server["port"]))
else:
diff --git a/response.py b/response.py
index 4533c86..781c4d8 100644
--- a/response.py
+++ b/response.py
@@ -42,6 +42,13 @@ class Response:
self.set_sender(sender)
self.count = count
+ @property
+ def receivers(self):
+ if type(self.channel) is list:
+ return self.channel
+ else:
+ return [ self.channel ]
+
@property
def content(self):
#FIXME: error when messages in self.messages are list!
diff --git a/server/IRC.py b/server/IRC.py
index 16bcabd..f05c2c8 100644
--- a/server/IRC.py
+++ b/server/IRC.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
-# Nemubot is a modulable IRC bot, built around XML configuration files.
-# Copyright (C) 2012 Mercier Pierre-Olivier
+# Nemubot is a smart and modulable IM bot.
+# Copyright (C) 2012-2014 nemunaire
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@@ -16,272 +16,32 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-import errno
-import os
-import ssl
-import socket
-import threading
-import traceback
-
-from channel import Channel
-from server.DCC import DCC
-from hooks import Hook
-import message
import server
-import xmlparser
+from server.socket import SocketServer
-class IRCServer(server.Server):
- """Class to interact with an IRC server"""
+class IRCServer(SocketServer):
- def __init__(self, node, nick, owner, realname, ssl=False):
- """Initialize an IRC server
-
- Arguments:
- node -- server node from XML configuration
- nick -- nick used by the bot on this server
- owner -- nick used by the bot owner on this server
- realname -- string used as realname on this server
- ssl -- require SSL?
- """
- self.node = node
-
- server.Server.__init__(self)
+ def __init__(self, node, nick, owner, realname):
+ SocketServer.__init__(self,
+ node["host"],
+ node["port"],
+ node["password"],
+ node.hasAttribute("ssl") and node["ssl"].lower() == "true")
self.nick = nick
self.owner = owner
self.realname = realname
- self.ssl = ssl
+ self.id = "TODO"
- # Listen private messages?
- self.listen_nick = True
-
- self.dcc_clients = dict()
-
- self.channels = dict()
- for chn in self.node.getNodes("channel"):
- chan = Channel(chn["name"], chn["password"])
- self.channels[chan.name] = chan
-
-
- @property
- def host(self):
- """Return the server hostname"""
- if self.node is not None and self.node.hasAttribute("server"):
- return self.node["server"]
- else:
- return "localhost"
-
- @property
- def port(self):
- """Return the connection port used on this server"""
- if self.node is not None and self.node.hasAttribute("port"):
- return self.node.getInt("port")
- else:
- return "6667"
-
- @property
- def password(self):
- """Return the password used to connect to this server"""
- if self.node is not None and self.node.hasAttribute("password"):
- return self.node["password"]
- else:
- return None
-
- @property
- def allow_all(self):
- """If True, treat message from all channels, not only listed one"""
- return (self.node is not None and self.node.hasAttribute("allowall")
- and self.node["allowall"] == "true")
-
- @property
- def autoconnect(self):
- """Autoconnect the server when added"""
- if self.node is not None and self.node.hasAttribute("autoconnect"):
- value = self.node["autoconnect"].lower()
- return value != "no" and value != "off" and value != "false"
- else:
- return False
-
- @property
- def id(self):
- """Gives the server identifiant"""
- return self.host + ":" + str(self.port)
-
- def register_hooks(self):
- self.add_hook(Hook(self.evt_channel, "JOIN"))
- self.add_hook(Hook(self.evt_channel, "PART"))
- self.add_hook(Hook(self.evt_server, "NICK"))
- self.add_hook(Hook(self.evt_server, "QUIT"))
- self.add_hook(Hook(self.evt_channel, "332"))
- self.add_hook(Hook(self.evt_channel, "353"))
-
- def evt_server(self, msg, srv):
- for chan in self.channels:
- self.channels[chan].treat(msg.cmd, msg)
-
- def evt_channel(self, msg, srv):
- if msg.receivers is not None:
- for receiver in msg.receivers:
- if receiver in self.channels:
- self.channels[receiver].treat(msg.cmd, msg)
-
- def accepted_channel(self, chan, sender=None):
- """Return True if the channel (or the user) is authorized"""
- return (self.allow_all or
- (chan in self.channels and (sender is None or sender in self.channels[chan].people)) or
- (self.listen_nick and chan == self.nick))
-
- def join(self, chan, password=None, force=False):
- """Join a channel"""
- if force or (chan is not None and
- self.connected and chan not in self.channels):
- self.channels[chan] = Channel(chan, password)
- if password is not None:
- self.s.send(("JOIN %s %s\r\n" % (chan, password)).encode())
- else:
- self.s.send(("JOIN %s\r\n" % chan).encode())
+ def _open(self):
+ if SocketServer._open(self):
+ if self.password is not None:
+ self.write("PASS :" + self.password)
+ self.write("NICK :" + self.nick)
+ self.write("USER %s %s bla :%s" % (self.nick, self.host, self.realname))
return True
- else:
- return False
+ return False
- def leave(self, chan):
- """Leave a channel"""
- if chan is not None and self.connected and chan in self.channels:
- if isinstance(chan, list):
- for c in chan:
- self.leave(c)
- else:
- self.s.send(("PART %s\r\n" % self.channels[chan].name).encode())
- del self.channels[chan]
- return True
- else:
- return False
-
-# Main loop
- def run(self):
- if not self.connected:
- self.s = socket.socket() #Create the socket
- if self.ssl:
- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- self.s = ctx.wrap_socket(self.s)
- try:
- self.s.connect((self.host, self.port)) #Connect to server
- except socket.error as e:
- self.s = None
- self.logger.critical("Unable to connect to %s:%d: %s",
- self.host, self.port, os.strerror(e.errno))
- return
- self.stopping.clear()
-
- if self.password != None:
- self.s.send(b"PASS " + self.password.encode () + b"\r\n")
- self.s.send(("NICK %s\r\n" % self.nick).encode ())
- self.s.send(("USER %s %s bla :%s\r\n" % (self.nick, self.host,
- self.realname)).encode())
- raw = self.s.recv(1024)
- if not raw:
- self.logger.critical("Unable to connect to %s:%d", self.host, self.port)
- return
- self.connected = True
- self.logger.info("Connection to %s:%d completed", self.host, self.port)
-
- if len(self.channels) > 0:
- for chn in self.channels.keys():
- self.join(self.channels[chn].name,
- self.channels[chn].password, force=True)
-
-
- readbuffer = b'' #Here we store all the messages from server
- try:
- while not self.stop:
- readbuffer = readbuffer + raw
- temp = readbuffer.split(b'\n')
- readbuffer = temp.pop()
-
- for line in temp:
- self.treat_msg(line)
- raw = self.s.recv(1024) #recieve server messages
- except socket.error:
- pass
-
- if self.connected:
- self.s.close()
- self.connected = False
- if self.closing_event is not None:
- self.closing_event()
- self.logger.info("Server `%s' successfully stopped.", self.id)
- self.stopping.set()
- # Rearm Thread
- threading.Thread.__init__(self)
-
-
-# Overwritted methods
-
- def disconnect(self):
- """Close the socket with the server and all DCC client connections"""
- #Close all DCC connection
- clts = [c for c in self.dcc_clients]
- for clt in clts:
- self.dcc_clients[clt].disconnect()
- return server.Server.disconnect(self)
-
-
-
-# Abstract methods
-
- def send_pong(self, cnt):
- """Send a PONG command to the server with argument cnt"""
- self.s.send(("PONG %s\r\n" % cnt).encode())
-
- def msg_treated(self, origin):
- """Do nothing; here for implement abstract class"""
- pass
-
- def send_dcc(self, msg, to):
- """Send a message through DCC connection"""
- if msg is not None and to is not None:
- realname = to.split("!")[1]
- if realname not in self.dcc_clients.keys():
- d = DCC(self, to)
- self.dcc_clients[realname] = d
- self.dcc_clients[realname].send_dcc(msg)
-
- def send_msg_final(self, channel, line, cmd="PRIVMSG", endl="\r\n"):
- """Send a message without checks or format"""
- #TODO: add something for post message treatment here
- if channel == self.nick:
- self.logger.warn("Nemubot talks to himself: %s", line, stack_info=True)
- if line is not None and channel is not None:
- if self.s is None:
- self.logger.warn("Attempt to send message on a non connected server: %s: %s", self.id, line, stack_info=True)
- elif len(line) < 442:
- self.s.send(("%s %s :%s%s" % (cmd, channel, line, endl)).encode ())
- else:
- self.logger.warn("Message truncated due to size (%d ; max : 442) : %s", len(line), line, stack_info=True)
- self.s.send (("%s %s :%s%s" % (cmd, channel, line[0:442]+"<…>", endl)).encode ())
-
- def send_msg_usr(self, user, msg):
- """Send a message to a user instead of a channel"""
- if user is not None and user[0] != "#":
- realname = user.split("!")[1]
- if realname in self.dcc_clients or user in self.dcc_clients:
- self.send_dcc(msg, user)
- else:
- for line in msg.split("\n"):
- if line != "":
- self.send_msg_final(user.split('!')[0], msg)
-
- def send_msg(self, channel, msg, cmd="PRIVMSG", endl="\r\n"):
- """Send a message to a channel"""
- if self.accepted_channel(channel):
- server.Server.send_msg(self, channel, msg, cmd, endl)
-
- def send_msg_verified(self, sender, channel, msg, cmd = "PRIVMSG", endl = "\r\n"):
- """Send a message to a channel, only if the source user is on this channel too"""
- if self.accepted_channel(channel, sender):
- self.send_msg_final(channel, msg, cmd, endl)
-
- def send_global(self, msg, cmd="PRIVMSG", endl="\r\n"):
- """Send a message to all channels on this server"""
- for channel in self.channels.keys():
- self.send_msg(channel, msg, cmd, endl)
+ def _close(self):
+ self.write("QUIT")
+ SocketServer._close(self)
diff --git a/server/__init__.py b/server/__init__.py
index 8fffe1d..9f423c2 100644
--- a/server/__init__.py
+++ b/server/__init__.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
-# Nemubot is a modulable IRC bot, built around XML configuration files.
-# Copyright (C) 2012 Mercier Pierre-Olivier
+# Nemubot is a smart and modulable IM bot.
+# Copyright (C) 2012-2014 nemunaire
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@@ -16,156 +16,70 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+import io
import logging
import socket
-import threading
+import queue
-class Server(threading.Thread):
- def __init__(self, socket = None):
- self.stop = False
- self.stopping = threading.Event()
- self.s = socket
- self.connected = self.s is not None
- self.closing_event = None
+# Lists for select
+_rlist = []
+_wlist = []
+_xlist = []
- self.moremessages = dict()
+# Extends from IOBase in order to be compatible with select function
+class AbstractServer(io.IOBase):
- self.logger = logging.getLogger("nemubot.server." + self.id)
+ """An abstract server: handle communication with an IM server"""
- threading.Thread.__init__(self)
+ def __init__(self, send_callback=None):
+ """Initialize an abstract server
- def isDCC(self, to=None):
- return to is not None and to in self.dcc_clients
+ Keyword argument:
+ send_callback -- Callback when developper want to send a message
+ """
- @property
- def ip(self):
- """Convert common IP representation to little-endian integer representation"""
- sum = 0
- if self.node.hasAttribute("ip"):
- ip = self.node["ip"]
+ self.logger = logging.getLogger("nemubot.server.TODO")
+ self._sending_queue = queue.Queue()
+ if send_callback is not None:
+ self._send_callback = send_callback
else:
- #TODO: find the external IP
- ip = "0.0.0.0"
- for b in ip.split("."):
- sum = 256 * sum + int(b)
- return sum
+ self._send_callback = self._write_select
- def toIP(self, input):
- """Convert little-endian int to IPv4 adress"""
- ip = ""
- for i in range(0,4):
- mod = input % 256
- ip = "%d.%s" % (mod, ip)
- input = (input - mod) / 256
- return ip[:len(ip) - 1]
- @property
- def id(self):
- """Gives the server identifiant"""
- raise NotImplemented()
+ def open(self):
+ """Generic open function that register the server un _rlist in case of successful _open"""
+ if self._open():
+ _rlist.append(self)
- def accepted_channel(self, msg, sender=None):
- return True
- def msg_treated(self, origin):
- """Action done on server when a message was treated"""
- raise NotImplemented()
+ def close(self):
+ """Generic close function that register the server un _{r,w,x}list in case of successful _close"""
+ if self._close():
+ if self in _rlist:
+ _rlist.remove(self)
+ if self in _wlist:
+ _wlist.remove(self)
+ if self in _xlist:
+ _xlist.remove(self)
- def send_response(self, res, origin):
- """Analyse a Response and send it"""
- if type(res.channel) != list:
- res.channel = [ res.channel ]
- for channel in res.channel:
- if channel != self.nick:
- self.send_msg(channel, res.get_message())
- else:
- channel = res.sender
- self.send_msg_usr(channel, res.get_message(), "NOTICE" if res.is_ctcp else "PRIVMSG")
+ def write(self, message):
+ """Send a message to the server using send_callback"""
+ self._send_callback(message)
- if not res.alone:
- if hasattr(self, "send_bot"):
- self.send_bot("NOMORE %s" % res.channel)
- self.moremessages[channel] = res
+ def write_select(self):
+ """Internal function used by the select function"""
+ try:
+ while not self._sending_queue.empty():
+ self._write(self._sending_queue.get_nowait())
+ _wlist.remove(self)
- def send_ctcp(self, to, msg, cmd="NOTICE", endl="\r\n"):
- """Send a message as CTCP response"""
- if msg is not None and to is not None:
- for line in msg.split("\n"):
- if line != "":
- self.send_msg_final(to.split("!")[0], "\x01" + line + "\x01", cmd, endl)
+ except queue.Empty:
+ pass
- def send_dcc(self, msg, to):
- """Send a message through DCC connection"""
- raise NotImplemented()
-
- def send_msg_final(self, channel, msg, cmd="PRIVMSG", endl="\r\n"):
- """Send a message without checks or format"""
- raise NotImplemented()
-
- def send_msg_usr(self, user, msg):
- """Send a message to a user instead of a channel"""
- raise NotImplemented()
-
- def send_msg(self, channel, msg, cmd="PRIVMSG", endl="\r\n"):
- """Send a message to a channel"""
- if msg is not None:
- for line in msg.split("\n"):
- if line != "":
- self.send_msg_final(channel, line, cmd, endl)
-
- def send_msg_verified(self, sender, channel, msg, cmd="PRIVMSG", endl="\r\n"):
- """A more secure way to send messages"""
- raise NotImplemented()
-
- def send_global(self, msg, cmd="PRIVMSG", endl="\r\n"):
- """Send a message to all channels on this server"""
- raise NotImplemented()
-
- def disconnect(self):
- """Close the socket with the server"""
- if self.connected:
- self.stop = True
- try:
- self.s.shutdown(socket.SHUT_RDWR)
- except socket.error:
- pass
-
- self.stopping.wait()
- return True
- else:
- return False
-
- def kill(self):
- """Just stop the main loop, don't close the socket directly"""
- if self.connected:
- self.stop = True
- self.connected = False
- #Send a message in order to close the socket
- try:
- self.s.send(("Bye!\r\n").encode ())
- except:
- pass
- self.stopping.wait()
- return True
- else:
- return False
-
- def launch(self, receive_action, verb=True):
- """Connect to the server if it is no yet connected"""
- self._receive_action = receive_action
- if not self.connected:
- self.stop = False
- self.logger.info("Entering main loop for server")
- try:
- self.start()
- except RuntimeError:
- pass
- elif verb:
- print (" Already connected.")
-
- def treat_msg(self, line, private=False):
- self._receive_action(self, line, private)
-
- def run(self):
- raise NotImplemented()
+ def _write_select(self, message):
+ """Send a message to the server safely through select"""
+ self._sending_queue.put(self.format(message))
+ self.logger.debug("Message '%s' appended to Queue", message)
+ if self not in _wlist:
+ _wlist.append(self)
diff --git a/server/socket.py b/server/socket.py
new file mode 100644
index 0000000..21d3d0c
--- /dev/null
+++ b/server/socket.py
@@ -0,0 +1,77 @@
+# -*- coding: utf-8 -*-
+
+# Nemubot is a smart and modulable IM bot.
+# Copyright (C) 2012-2014 nemunaire
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+import ssl
+import socket
+
+from server import AbstractServer
+
+class SocketServer(AbstractServer):
+
+ def __init__(self, host, port=6667, password=None, ssl=False):
+ AbstractServer.__init__(self)
+ self.host = host
+ self.port = int(port)
+ self.password = password
+ self.ssl = ssl
+
+ self.socket = None
+ self.readbuffer = b''
+
+ def fileno(self):
+ return self.socket.fileno() if self.socket else None
+
+ def _open(self):
+ # Create the socket
+ self.socket = socket.socket()
+
+ # Wrap the socket for SSL
+ if self.ssl:
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ self.socket = ctx.wrap_socket(self.socket)
+
+ try:
+ self.socket.connect((self.host, self.port)) #Connect to server
+ self.logger.info("Connected to %s:%d", self.host, self.port)
+ except socket.error as e:
+ self.socket = None
+ self.logger.critical("Unable to connect to %s:%d: %s",
+ self.host, self.port, os.strerror(e.errno))
+ return False
+
+ return True
+
+ def _close(self):
+ if self.socket is not None:
+ self.socket.shutdown(SHUT_RDWR)
+ self.socket.close()
+ self.socket = None
+
+ def _write(self, cnt):
+ self.socket.send(cnt)
+
+ def format(self, txt):
+ return txt.encode() + b'\r\n'
+
+ def read(self):
+ raw = self.socket.recv(1024)
+ temp = (self.readbuffer + raw).split(b'\r\n')
+ self.readbuffer = temp.pop()
+
+ for line in temp:
+ yield line