(wip) use select instead of a thread by server. Currently read and write seems to work properly, we lost some code (like more, next, ...) that never should be in server part
This commit is contained in:
parent
28c1ad088b
commit
81593a493b
33
bot.py
33
bot.py
@ -21,6 +21,7 @@ from datetime import timedelta
|
|||||||
import logging
|
import logging
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
import re
|
import re
|
||||||
|
from select import select
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
@ -38,7 +39,7 @@ import response
|
|||||||
|
|
||||||
logger = logging.getLogger("nemubot.bot")
|
logger = logging.getLogger("nemubot.bot")
|
||||||
|
|
||||||
class Bot:
|
class Bot(threading.Thread):
|
||||||
|
|
||||||
"""Class containing the bot context and ensuring key goals"""
|
"""Class containing the bot context and ensuring key goals"""
|
||||||
|
|
||||||
@ -51,6 +52,8 @@ class Bot:
|
|||||||
data_path -- Path to directory where store bot context data
|
data_path -- Path to directory where store bot context data
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
threading.Thread.__init__(self)
|
||||||
|
|
||||||
logger.info("Initiate nemubot v%s", __version__)
|
logger.info("Initiate nemubot v%s", __version__)
|
||||||
|
|
||||||
# External IP for accessing this bot
|
# External IP for accessing this bot
|
||||||
@ -90,6 +93,22 @@ class Bot:
|
|||||||
self)
|
self)
|
||||||
|
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
from server import _rlist, _wlist, _xlist
|
||||||
|
|
||||||
|
self.stop = False
|
||||||
|
while not self.stop:
|
||||||
|
rl, wl, xl = select(_rlist, _wlist, _xlist, 0.1)
|
||||||
|
|
||||||
|
for x in xl:
|
||||||
|
x.exception()
|
||||||
|
for w in wl:
|
||||||
|
w.write_select()
|
||||||
|
for r in rl:
|
||||||
|
for i in r.read():
|
||||||
|
self.receive_message(r, i)
|
||||||
|
|
||||||
|
|
||||||
def init_ctcp_capabilities(self):
|
def init_ctcp_capabilities(self):
|
||||||
"""Reset existing CTCP capabilities to default one"""
|
"""Reset existing CTCP capabilities to default one"""
|
||||||
|
|
||||||
@ -277,17 +296,16 @@ class Bot:
|
|||||||
c.start()
|
c.start()
|
||||||
|
|
||||||
|
|
||||||
def add_server(self, node, nick, owner, realname, ssl=False):
|
def add_server(self, node, nick, owner, realname):
|
||||||
"""Add a new server to the context"""
|
"""Add a new server to the context"""
|
||||||
srv = IRCServer(node, nick, owner, realname, ssl)
|
srv = IRCServer(node, nick, owner, realname)
|
||||||
srv.add_hook = lambda h: self.hooks.add_hook("irc_hook", h, self)
|
srv.add_hook = lambda h: self.hooks.add_hook("irc_hook", h, self)
|
||||||
srv.add_networkbot = self.add_networkbot
|
srv.add_networkbot = self.add_networkbot
|
||||||
srv.send_bot = lambda d: self.send_networkbot(srv, d)
|
srv.send_bot = lambda d: self.send_networkbot(srv, d)
|
||||||
srv.register_hooks()
|
#srv.register_hooks()
|
||||||
if srv.id not in self.servers:
|
if srv.id not in self.servers:
|
||||||
self.servers[srv.id] = srv
|
self.servers[srv.id] = srv
|
||||||
if srv.autoconnect:
|
srv.open()
|
||||||
srv.launch(self.receive_message)
|
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
@ -378,6 +396,9 @@ class Bot:
|
|||||||
for srv in k:
|
for srv in k:
|
||||||
self.servers[srv].disconnect()
|
self.servers[srv].disconnect()
|
||||||
|
|
||||||
|
self.stop = True
|
||||||
|
|
||||||
|
|
||||||
# Hooks cache
|
# Hooks cache
|
||||||
|
|
||||||
def create_cache(self, name):
|
def create_cache(self, name):
|
||||||
|
31
consumer.py
31
consumer.py
@ -46,9 +46,8 @@ class MessageConsumer:
|
|||||||
def treat_in(self, context, msg):
|
def treat_in(self, context, msg):
|
||||||
"""Treat the input message"""
|
"""Treat the input message"""
|
||||||
if msg.cmd == "PING":
|
if msg.cmd == "PING":
|
||||||
self.srv.send_pong(msg.params[0])
|
self.srv.write("%s :%s" % ("PONG", msg.params[0]))
|
||||||
elif hasattr(msg, "receivers"):
|
elif hasattr(msg, "receivers"):
|
||||||
msg.receivers = [ receiver for receiver in msg.receivers if self.srv.accepted_channel(receiver) ]
|
|
||||||
if msg.receivers:
|
if msg.receivers:
|
||||||
# All messages
|
# All messages
|
||||||
context.treat_pre(msg, self.srv)
|
context.treat_pre(msg, self.srv)
|
||||||
@ -63,20 +62,28 @@ class MessageConsumer:
|
|||||||
|
|
||||||
elif isinstance(res, response.Response):
|
elif isinstance(res, response.Response):
|
||||||
# Define the destination server
|
# Define the destination server
|
||||||
if (res.server is not None and
|
to_server = None
|
||||||
isinstance(res.server, str) and res.server in context.servers):
|
if res.server is None:
|
||||||
res.server = context.servers[res.server]
|
to_server = self.srv
|
||||||
if (res.server is not None and
|
res.server = self.srv.id
|
||||||
not isinstance(res.server, server.Server)):
|
elif isinstance(res.server, str) and res.server in context.servers:
|
||||||
|
to_server = context.servers[res.server]
|
||||||
|
|
||||||
|
if to_server is None:
|
||||||
logger.error("The server defined in this response doesn't "
|
logger.error("The server defined in this response doesn't "
|
||||||
"exist: %s", res.server)
|
"exist: %s", res.server)
|
||||||
res.server = None
|
return False
|
||||||
if res.server is None:
|
|
||||||
res.server = self.srv
|
|
||||||
|
|
||||||
# Sent the message only if treat_post authorize it
|
# Sent the message only if treat_post authorize it
|
||||||
if context.treat_post(res):
|
if context.treat_post(res):
|
||||||
res.server.send_response(res, self.data)
|
if type(res.channel) != list:
|
||||||
|
res.channel = [ res.channel ]
|
||||||
|
for channel in res.channel:
|
||||||
|
if channel != to_server.nick:
|
||||||
|
to_server.write("%s %s :%s" % ("PRIVMSG", channel, res.get_message()))
|
||||||
|
else:
|
||||||
|
channel = res.sender
|
||||||
|
to_server.write("%s %s :%s" % ("NOTICE" if res.is_ctcp else "PRIVMSG", channel, res.get_message()))
|
||||||
|
|
||||||
elif res is not None:
|
elif res is not None:
|
||||||
logger.error("Unrecognized response type: %s", res)
|
logger.error("Unrecognized response type: %s", res)
|
||||||
@ -98,7 +105,7 @@ class MessageConsumer:
|
|||||||
self.treat_out(context, res)
|
self.treat_out(context, res)
|
||||||
|
|
||||||
# Inform that the message has been treated
|
# Inform that the message has been treated
|
||||||
self.srv.msg_treated(self.data)
|
#self.srv.msg_treated(self.data)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -68,6 +68,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
print ("Nemubot v%s ready, my PID is %i!" % (bot.__version__,
|
print ("Nemubot v%s ready, my PID is %i!" % (bot.__version__,
|
||||||
os.getpid()))
|
os.getpid()))
|
||||||
|
context.start()
|
||||||
while prmpt.run(context):
|
while prmpt.run(context):
|
||||||
try:
|
try:
|
||||||
# Reload context
|
# Reload context
|
||||||
|
@ -69,8 +69,7 @@ def load_file(filename, context):
|
|||||||
nick = server["nick"] if server.hasAttribute("nick") else config["nick"]
|
nick = server["nick"] if server.hasAttribute("nick") else config["nick"]
|
||||||
owner = server["owner"] if server.hasAttribute("owner") else config["owner"]
|
owner = server["owner"] if server.hasAttribute("owner") else config["owner"]
|
||||||
realname = server["realname"] if server.hasAttribute("realname") else config["realname"]
|
realname = server["realname"] if server.hasAttribute("realname") else config["realname"]
|
||||||
if context.add_server(server, nick, owner, realname,
|
if context.add_server(server, nick, owner, realname):
|
||||||
server.hasAttribute("ssl")):
|
|
||||||
print("Server `%s:%s' successfully added." %
|
print("Server `%s:%s' successfully added." %
|
||||||
(server["server"], server["port"]))
|
(server["server"], server["port"]))
|
||||||
else:
|
else:
|
||||||
|
@ -42,6 +42,13 @@ class Response:
|
|||||||
self.set_sender(sender)
|
self.set_sender(sender)
|
||||||
self.count = count
|
self.count = count
|
||||||
|
|
||||||
|
@property
|
||||||
|
def receivers(self):
|
||||||
|
if type(self.channel) is list:
|
||||||
|
return self.channel
|
||||||
|
else:
|
||||||
|
return [ self.channel ]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def content(self):
|
def content(self):
|
||||||
#FIXME: error when messages in self.messages are list!
|
#FIXME: error when messages in self.messages are list!
|
||||||
|
280
server/IRC.py
280
server/IRC.py
@ -1,7 +1,7 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
# Nemubot is a modulable IRC bot, built around XML configuration files.
|
# Nemubot is a smart and modulable IM bot.
|
||||||
# Copyright (C) 2012 Mercier Pierre-Olivier
|
# Copyright (C) 2012-2014 nemunaire
|
||||||
#
|
#
|
||||||
# This program is free software: you can redistribute it and/or modify
|
# This program is free software: you can redistribute it and/or modify
|
||||||
# it under the terms of the GNU Affero General Public License as published by
|
# it under the terms of the GNU Affero General Public License as published by
|
||||||
@ -16,272 +16,32 @@
|
|||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
import errno
|
|
||||||
import os
|
|
||||||
import ssl
|
|
||||||
import socket
|
|
||||||
import threading
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
from channel import Channel
|
|
||||||
from server.DCC import DCC
|
|
||||||
from hooks import Hook
|
|
||||||
import message
|
|
||||||
import server
|
import server
|
||||||
import xmlparser
|
from server.socket import SocketServer
|
||||||
|
|
||||||
class IRCServer(server.Server):
|
class IRCServer(SocketServer):
|
||||||
"""Class to interact with an IRC server"""
|
|
||||||
|
|
||||||
def __init__(self, node, nick, owner, realname, ssl=False):
|
def __init__(self, node, nick, owner, realname):
|
||||||
"""Initialize an IRC server
|
SocketServer.__init__(self,
|
||||||
|
node["host"],
|
||||||
Arguments:
|
node["port"],
|
||||||
node -- server node from XML configuration
|
node["password"],
|
||||||
nick -- nick used by the bot on this server
|
node.hasAttribute("ssl") and node["ssl"].lower() == "true")
|
||||||
owner -- nick used by the bot owner on this server
|
|
||||||
realname -- string used as realname on this server
|
|
||||||
ssl -- require SSL?
|
|
||||||
"""
|
|
||||||
self.node = node
|
|
||||||
|
|
||||||
server.Server.__init__(self)
|
|
||||||
|
|
||||||
self.nick = nick
|
self.nick = nick
|
||||||
self.owner = owner
|
self.owner = owner
|
||||||
self.realname = realname
|
self.realname = realname
|
||||||
self.ssl = ssl
|
self.id = "TODO"
|
||||||
|
|
||||||
# Listen private messages?
|
def _open(self):
|
||||||
self.listen_nick = True
|
if SocketServer._open(self):
|
||||||
|
if self.password is not None:
|
||||||
self.dcc_clients = dict()
|
self.write("PASS :" + self.password)
|
||||||
|
self.write("NICK :" + self.nick)
|
||||||
self.channels = dict()
|
self.write("USER %s %s bla :%s" % (self.nick, self.host, self.realname))
|
||||||
for chn in self.node.getNodes("channel"):
|
|
||||||
chan = Channel(chn["name"], chn["password"])
|
|
||||||
self.channels[chan.name] = chan
|
|
||||||
|
|
||||||
|
|
||||||
@property
|
|
||||||
def host(self):
|
|
||||||
"""Return the server hostname"""
|
|
||||||
if self.node is not None and self.node.hasAttribute("server"):
|
|
||||||
return self.node["server"]
|
|
||||||
else:
|
|
||||||
return "localhost"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def port(self):
|
|
||||||
"""Return the connection port used on this server"""
|
|
||||||
if self.node is not None and self.node.hasAttribute("port"):
|
|
||||||
return self.node.getInt("port")
|
|
||||||
else:
|
|
||||||
return "6667"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def password(self):
|
|
||||||
"""Return the password used to connect to this server"""
|
|
||||||
if self.node is not None and self.node.hasAttribute("password"):
|
|
||||||
return self.node["password"]
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def allow_all(self):
|
|
||||||
"""If True, treat message from all channels, not only listed one"""
|
|
||||||
return (self.node is not None and self.node.hasAttribute("allowall")
|
|
||||||
and self.node["allowall"] == "true")
|
|
||||||
|
|
||||||
@property
|
|
||||||
def autoconnect(self):
|
|
||||||
"""Autoconnect the server when added"""
|
|
||||||
if self.node is not None and self.node.hasAttribute("autoconnect"):
|
|
||||||
value = self.node["autoconnect"].lower()
|
|
||||||
return value != "no" and value != "off" and value != "false"
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def id(self):
|
|
||||||
"""Gives the server identifiant"""
|
|
||||||
return self.host + ":" + str(self.port)
|
|
||||||
|
|
||||||
def register_hooks(self):
|
|
||||||
self.add_hook(Hook(self.evt_channel, "JOIN"))
|
|
||||||
self.add_hook(Hook(self.evt_channel, "PART"))
|
|
||||||
self.add_hook(Hook(self.evt_server, "NICK"))
|
|
||||||
self.add_hook(Hook(self.evt_server, "QUIT"))
|
|
||||||
self.add_hook(Hook(self.evt_channel, "332"))
|
|
||||||
self.add_hook(Hook(self.evt_channel, "353"))
|
|
||||||
|
|
||||||
def evt_server(self, msg, srv):
|
|
||||||
for chan in self.channels:
|
|
||||||
self.channels[chan].treat(msg.cmd, msg)
|
|
||||||
|
|
||||||
def evt_channel(self, msg, srv):
|
|
||||||
if msg.receivers is not None:
|
|
||||||
for receiver in msg.receivers:
|
|
||||||
if receiver in self.channels:
|
|
||||||
self.channels[receiver].treat(msg.cmd, msg)
|
|
||||||
|
|
||||||
def accepted_channel(self, chan, sender=None):
|
|
||||||
"""Return True if the channel (or the user) is authorized"""
|
|
||||||
return (self.allow_all or
|
|
||||||
(chan in self.channels and (sender is None or sender in self.channels[chan].people)) or
|
|
||||||
(self.listen_nick and chan == self.nick))
|
|
||||||
|
|
||||||
def join(self, chan, password=None, force=False):
|
|
||||||
"""Join a channel"""
|
|
||||||
if force or (chan is not None and
|
|
||||||
self.connected and chan not in self.channels):
|
|
||||||
self.channels[chan] = Channel(chan, password)
|
|
||||||
if password is not None:
|
|
||||||
self.s.send(("JOIN %s %s\r\n" % (chan, password)).encode())
|
|
||||||
else:
|
|
||||||
self.s.send(("JOIN %s\r\n" % chan).encode())
|
|
||||||
return True
|
return True
|
||||||
else:
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def leave(self, chan):
|
def _close(self):
|
||||||
"""Leave a channel"""
|
self.write("QUIT")
|
||||||
if chan is not None and self.connected and chan in self.channels:
|
SocketServer._close(self)
|
||||||
if isinstance(chan, list):
|
|
||||||
for c in chan:
|
|
||||||
self.leave(c)
|
|
||||||
else:
|
|
||||||
self.s.send(("PART %s\r\n" % self.channels[chan].name).encode())
|
|
||||||
del self.channels[chan]
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Main loop
|
|
||||||
def run(self):
|
|
||||||
if not self.connected:
|
|
||||||
self.s = socket.socket() #Create the socket
|
|
||||||
if self.ssl:
|
|
||||||
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
|
|
||||||
self.s = ctx.wrap_socket(self.s)
|
|
||||||
try:
|
|
||||||
self.s.connect((self.host, self.port)) #Connect to server
|
|
||||||
except socket.error as e:
|
|
||||||
self.s = None
|
|
||||||
self.logger.critical("Unable to connect to %s:%d: %s",
|
|
||||||
self.host, self.port, os.strerror(e.errno))
|
|
||||||
return
|
|
||||||
self.stopping.clear()
|
|
||||||
|
|
||||||
if self.password != None:
|
|
||||||
self.s.send(b"PASS " + self.password.encode () + b"\r\n")
|
|
||||||
self.s.send(("NICK %s\r\n" % self.nick).encode ())
|
|
||||||
self.s.send(("USER %s %s bla :%s\r\n" % (self.nick, self.host,
|
|
||||||
self.realname)).encode())
|
|
||||||
raw = self.s.recv(1024)
|
|
||||||
if not raw:
|
|
||||||
self.logger.critical("Unable to connect to %s:%d", self.host, self.port)
|
|
||||||
return
|
|
||||||
self.connected = True
|
|
||||||
self.logger.info("Connection to %s:%d completed", self.host, self.port)
|
|
||||||
|
|
||||||
if len(self.channels) > 0:
|
|
||||||
for chn in self.channels.keys():
|
|
||||||
self.join(self.channels[chn].name,
|
|
||||||
self.channels[chn].password, force=True)
|
|
||||||
|
|
||||||
|
|
||||||
readbuffer = b'' #Here we store all the messages from server
|
|
||||||
try:
|
|
||||||
while not self.stop:
|
|
||||||
readbuffer = readbuffer + raw
|
|
||||||
temp = readbuffer.split(b'\n')
|
|
||||||
readbuffer = temp.pop()
|
|
||||||
|
|
||||||
for line in temp:
|
|
||||||
self.treat_msg(line)
|
|
||||||
raw = self.s.recv(1024) #recieve server messages
|
|
||||||
except socket.error:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if self.connected:
|
|
||||||
self.s.close()
|
|
||||||
self.connected = False
|
|
||||||
if self.closing_event is not None:
|
|
||||||
self.closing_event()
|
|
||||||
self.logger.info("Server `%s' successfully stopped.", self.id)
|
|
||||||
self.stopping.set()
|
|
||||||
# Rearm Thread
|
|
||||||
threading.Thread.__init__(self)
|
|
||||||
|
|
||||||
|
|
||||||
# Overwritted methods
|
|
||||||
|
|
||||||
def disconnect(self):
|
|
||||||
"""Close the socket with the server and all DCC client connections"""
|
|
||||||
#Close all DCC connection
|
|
||||||
clts = [c for c in self.dcc_clients]
|
|
||||||
for clt in clts:
|
|
||||||
self.dcc_clients[clt].disconnect()
|
|
||||||
return server.Server.disconnect(self)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Abstract methods
|
|
||||||
|
|
||||||
def send_pong(self, cnt):
|
|
||||||
"""Send a PONG command to the server with argument cnt"""
|
|
||||||
self.s.send(("PONG %s\r\n" % cnt).encode())
|
|
||||||
|
|
||||||
def msg_treated(self, origin):
|
|
||||||
"""Do nothing; here for implement abstract class"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def send_dcc(self, msg, to):
|
|
||||||
"""Send a message through DCC connection"""
|
|
||||||
if msg is not None and to is not None:
|
|
||||||
realname = to.split("!")[1]
|
|
||||||
if realname not in self.dcc_clients.keys():
|
|
||||||
d = DCC(self, to)
|
|
||||||
self.dcc_clients[realname] = d
|
|
||||||
self.dcc_clients[realname].send_dcc(msg)
|
|
||||||
|
|
||||||
def send_msg_final(self, channel, line, cmd="PRIVMSG", endl="\r\n"):
|
|
||||||
"""Send a message without checks or format"""
|
|
||||||
#TODO: add something for post message treatment here
|
|
||||||
if channel == self.nick:
|
|
||||||
self.logger.warn("Nemubot talks to himself: %s", line, stack_info=True)
|
|
||||||
if line is not None and channel is not None:
|
|
||||||
if self.s is None:
|
|
||||||
self.logger.warn("Attempt to send message on a non connected server: %s: %s", self.id, line, stack_info=True)
|
|
||||||
elif len(line) < 442:
|
|
||||||
self.s.send(("%s %s :%s%s" % (cmd, channel, line, endl)).encode ())
|
|
||||||
else:
|
|
||||||
self.logger.warn("Message truncated due to size (%d ; max : 442) : %s", len(line), line, stack_info=True)
|
|
||||||
self.s.send (("%s %s :%s%s" % (cmd, channel, line[0:442]+"<…>", endl)).encode ())
|
|
||||||
|
|
||||||
def send_msg_usr(self, user, msg):
|
|
||||||
"""Send a message to a user instead of a channel"""
|
|
||||||
if user is not None and user[0] != "#":
|
|
||||||
realname = user.split("!")[1]
|
|
||||||
if realname in self.dcc_clients or user in self.dcc_clients:
|
|
||||||
self.send_dcc(msg, user)
|
|
||||||
else:
|
|
||||||
for line in msg.split("\n"):
|
|
||||||
if line != "":
|
|
||||||
self.send_msg_final(user.split('!')[0], msg)
|
|
||||||
|
|
||||||
def send_msg(self, channel, msg, cmd="PRIVMSG", endl="\r\n"):
|
|
||||||
"""Send a message to a channel"""
|
|
||||||
if self.accepted_channel(channel):
|
|
||||||
server.Server.send_msg(self, channel, msg, cmd, endl)
|
|
||||||
|
|
||||||
def send_msg_verified(self, sender, channel, msg, cmd = "PRIVMSG", endl = "\r\n"):
|
|
||||||
"""Send a message to a channel, only if the source user is on this channel too"""
|
|
||||||
if self.accepted_channel(channel, sender):
|
|
||||||
self.send_msg_final(channel, msg, cmd, endl)
|
|
||||||
|
|
||||||
def send_global(self, msg, cmd="PRIVMSG", endl="\r\n"):
|
|
||||||
"""Send a message to all channels on this server"""
|
|
||||||
for channel in self.channels.keys():
|
|
||||||
self.send_msg(channel, msg, cmd, endl)
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
# Nemubot is a modulable IRC bot, built around XML configuration files.
|
# Nemubot is a smart and modulable IM bot.
|
||||||
# Copyright (C) 2012 Mercier Pierre-Olivier
|
# Copyright (C) 2012-2014 nemunaire
|
||||||
#
|
#
|
||||||
# This program is free software: you can redistribute it and/or modify
|
# This program is free software: you can redistribute it and/or modify
|
||||||
# it under the terms of the GNU Affero General Public License as published by
|
# it under the terms of the GNU Affero General Public License as published by
|
||||||
@ -16,156 +16,70 @@
|
|||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
import io
|
||||||
import logging
|
import logging
|
||||||
import socket
|
import socket
|
||||||
import threading
|
import queue
|
||||||
|
|
||||||
class Server(threading.Thread):
|
# Lists for select
|
||||||
def __init__(self, socket = None):
|
_rlist = []
|
||||||
self.stop = False
|
_wlist = []
|
||||||
self.stopping = threading.Event()
|
_xlist = []
|
||||||
self.s = socket
|
|
||||||
self.connected = self.s is not None
|
|
||||||
self.closing_event = None
|
|
||||||
|
|
||||||
self.moremessages = dict()
|
# Extends from IOBase in order to be compatible with select function
|
||||||
|
class AbstractServer(io.IOBase):
|
||||||
|
|
||||||
self.logger = logging.getLogger("nemubot.server." + self.id)
|
"""An abstract server: handle communication with an IM server"""
|
||||||
|
|
||||||
threading.Thread.__init__(self)
|
def __init__(self, send_callback=None):
|
||||||
|
"""Initialize an abstract server
|
||||||
|
|
||||||
def isDCC(self, to=None):
|
Keyword argument:
|
||||||
return to is not None and to in self.dcc_clients
|
send_callback -- Callback when developper want to send a message
|
||||||
|
"""
|
||||||
|
|
||||||
@property
|
self.logger = logging.getLogger("nemubot.server.TODO")
|
||||||
def ip(self):
|
self._sending_queue = queue.Queue()
|
||||||
"""Convert common IP representation to little-endian integer representation"""
|
if send_callback is not None:
|
||||||
sum = 0
|
self._send_callback = send_callback
|
||||||
if self.node.hasAttribute("ip"):
|
|
||||||
ip = self.node["ip"]
|
|
||||||
else:
|
else:
|
||||||
#TODO: find the external IP
|
self._send_callback = self._write_select
|
||||||
ip = "0.0.0.0"
|
|
||||||
for b in ip.split("."):
|
|
||||||
sum = 256 * sum + int(b)
|
|
||||||
return sum
|
|
||||||
|
|
||||||
def toIP(self, input):
|
|
||||||
"""Convert little-endian int to IPv4 adress"""
|
|
||||||
ip = ""
|
|
||||||
for i in range(0,4):
|
|
||||||
mod = input % 256
|
|
||||||
ip = "%d.%s" % (mod, ip)
|
|
||||||
input = (input - mod) / 256
|
|
||||||
return ip[:len(ip) - 1]
|
|
||||||
|
|
||||||
@property
|
def open(self):
|
||||||
def id(self):
|
"""Generic open function that register the server un _rlist in case of successful _open"""
|
||||||
"""Gives the server identifiant"""
|
if self._open():
|
||||||
raise NotImplemented()
|
_rlist.append(self)
|
||||||
|
|
||||||
def accepted_channel(self, msg, sender=None):
|
|
||||||
return True
|
|
||||||
|
|
||||||
def msg_treated(self, origin):
|
def close(self):
|
||||||
"""Action done on server when a message was treated"""
|
"""Generic close function that register the server un _{r,w,x}list in case of successful _close"""
|
||||||
raise NotImplemented()
|
if self._close():
|
||||||
|
if self in _rlist:
|
||||||
|
_rlist.remove(self)
|
||||||
|
if self in _wlist:
|
||||||
|
_wlist.remove(self)
|
||||||
|
if self in _xlist:
|
||||||
|
_xlist.remove(self)
|
||||||
|
|
||||||
def send_response(self, res, origin):
|
|
||||||
"""Analyse a Response and send it"""
|
|
||||||
if type(res.channel) != list:
|
|
||||||
res.channel = [ res.channel ]
|
|
||||||
|
|
||||||
for channel in res.channel:
|
def write(self, message):
|
||||||
if channel != self.nick:
|
"""Send a message to the server using send_callback"""
|
||||||
self.send_msg(channel, res.get_message())
|
self._send_callback(message)
|
||||||
else:
|
|
||||||
channel = res.sender
|
|
||||||
self.send_msg_usr(channel, res.get_message(), "NOTICE" if res.is_ctcp else "PRIVMSG")
|
|
||||||
|
|
||||||
if not res.alone:
|
def write_select(self):
|
||||||
if hasattr(self, "send_bot"):
|
"""Internal function used by the select function"""
|
||||||
self.send_bot("NOMORE %s" % res.channel)
|
|
||||||
self.moremessages[channel] = res
|
|
||||||
|
|
||||||
def send_ctcp(self, to, msg, cmd="NOTICE", endl="\r\n"):
|
|
||||||
"""Send a message as CTCP response"""
|
|
||||||
if msg is not None and to is not None:
|
|
||||||
for line in msg.split("\n"):
|
|
||||||
if line != "":
|
|
||||||
self.send_msg_final(to.split("!")[0], "\x01" + line + "\x01", cmd, endl)
|
|
||||||
|
|
||||||
def send_dcc(self, msg, to):
|
|
||||||
"""Send a message through DCC connection"""
|
|
||||||
raise NotImplemented()
|
|
||||||
|
|
||||||
def send_msg_final(self, channel, msg, cmd="PRIVMSG", endl="\r\n"):
|
|
||||||
"""Send a message without checks or format"""
|
|
||||||
raise NotImplemented()
|
|
||||||
|
|
||||||
def send_msg_usr(self, user, msg):
|
|
||||||
"""Send a message to a user instead of a channel"""
|
|
||||||
raise NotImplemented()
|
|
||||||
|
|
||||||
def send_msg(self, channel, msg, cmd="PRIVMSG", endl="\r\n"):
|
|
||||||
"""Send a message to a channel"""
|
|
||||||
if msg is not None:
|
|
||||||
for line in msg.split("\n"):
|
|
||||||
if line != "":
|
|
||||||
self.send_msg_final(channel, line, cmd, endl)
|
|
||||||
|
|
||||||
def send_msg_verified(self, sender, channel, msg, cmd="PRIVMSG", endl="\r\n"):
|
|
||||||
"""A more secure way to send messages"""
|
|
||||||
raise NotImplemented()
|
|
||||||
|
|
||||||
def send_global(self, msg, cmd="PRIVMSG", endl="\r\n"):
|
|
||||||
"""Send a message to all channels on this server"""
|
|
||||||
raise NotImplemented()
|
|
||||||
|
|
||||||
def disconnect(self):
|
|
||||||
"""Close the socket with the server"""
|
|
||||||
if self.connected:
|
|
||||||
self.stop = True
|
|
||||||
try:
|
try:
|
||||||
self.s.shutdown(socket.SHUT_RDWR)
|
while not self._sending_queue.empty():
|
||||||
except socket.error:
|
self._write(self._sending_queue.get_nowait())
|
||||||
|
_wlist.remove(self)
|
||||||
|
|
||||||
|
except queue.Empty:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
self.stopping.wait()
|
def _write_select(self, message):
|
||||||
return True
|
"""Send a message to the server safely through select"""
|
||||||
else:
|
self._sending_queue.put(self.format(message))
|
||||||
return False
|
self.logger.debug("Message '%s' appended to Queue", message)
|
||||||
|
if self not in _wlist:
|
||||||
def kill(self):
|
_wlist.append(self)
|
||||||
"""Just stop the main loop, don't close the socket directly"""
|
|
||||||
if self.connected:
|
|
||||||
self.stop = True
|
|
||||||
self.connected = False
|
|
||||||
#Send a message in order to close the socket
|
|
||||||
try:
|
|
||||||
self.s.send(("Bye!\r\n").encode ())
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
self.stopping.wait()
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
|
|
||||||
def launch(self, receive_action, verb=True):
|
|
||||||
"""Connect to the server if it is no yet connected"""
|
|
||||||
self._receive_action = receive_action
|
|
||||||
if not self.connected:
|
|
||||||
self.stop = False
|
|
||||||
self.logger.info("Entering main loop for server")
|
|
||||||
try:
|
|
||||||
self.start()
|
|
||||||
except RuntimeError:
|
|
||||||
pass
|
|
||||||
elif verb:
|
|
||||||
print (" Already connected.")
|
|
||||||
|
|
||||||
def treat_msg(self, line, private=False):
|
|
||||||
self._receive_action(self, line, private)
|
|
||||||
|
|
||||||
def run(self):
|
|
||||||
raise NotImplemented()
|
|
||||||
|
77
server/socket.py
Normal file
77
server/socket.py
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
# Nemubot is a smart and modulable IM bot.
|
||||||
|
# Copyright (C) 2012-2014 nemunaire
|
||||||
|
#
|
||||||
|
# This program is free software: you can redistribute it and/or modify
|
||||||
|
# it under the terms of the GNU Affero General Public License as published by
|
||||||
|
# the Free Software Foundation, either version 3 of the License, or
|
||||||
|
# (at your option) any later version.
|
||||||
|
#
|
||||||
|
# This program is distributed in the hope that it will be useful,
|
||||||
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
# GNU Affero General Public License for more details.
|
||||||
|
#
|
||||||
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
import ssl
|
||||||
|
import socket
|
||||||
|
|
||||||
|
from server import AbstractServer
|
||||||
|
|
||||||
|
class SocketServer(AbstractServer):
|
||||||
|
|
||||||
|
def __init__(self, host, port=6667, password=None, ssl=False):
|
||||||
|
AbstractServer.__init__(self)
|
||||||
|
self.host = host
|
||||||
|
self.port = int(port)
|
||||||
|
self.password = password
|
||||||
|
self.ssl = ssl
|
||||||
|
|
||||||
|
self.socket = None
|
||||||
|
self.readbuffer = b''
|
||||||
|
|
||||||
|
def fileno(self):
|
||||||
|
return self.socket.fileno() if self.socket else None
|
||||||
|
|
||||||
|
def _open(self):
|
||||||
|
# Create the socket
|
||||||
|
self.socket = socket.socket()
|
||||||
|
|
||||||
|
# Wrap the socket for SSL
|
||||||
|
if self.ssl:
|
||||||
|
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
|
||||||
|
self.socket = ctx.wrap_socket(self.socket)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.socket.connect((self.host, self.port)) #Connect to server
|
||||||
|
self.logger.info("Connected to %s:%d", self.host, self.port)
|
||||||
|
except socket.error as e:
|
||||||
|
self.socket = None
|
||||||
|
self.logger.critical("Unable to connect to %s:%d: %s",
|
||||||
|
self.host, self.port, os.strerror(e.errno))
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _close(self):
|
||||||
|
if self.socket is not None:
|
||||||
|
self.socket.shutdown(SHUT_RDWR)
|
||||||
|
self.socket.close()
|
||||||
|
self.socket = None
|
||||||
|
|
||||||
|
def _write(self, cnt):
|
||||||
|
self.socket.send(cnt)
|
||||||
|
|
||||||
|
def format(self, txt):
|
||||||
|
return txt.encode() + b'\r\n'
|
||||||
|
|
||||||
|
def read(self):
|
||||||
|
raw = self.socket.recv(1024)
|
||||||
|
temp = (self.readbuffer + raw).split(b'\r\n')
|
||||||
|
self.readbuffer = temp.pop()
|
||||||
|
|
||||||
|
for line in temp:
|
||||||
|
yield line
|
Loading…
x
Reference in New Issue
Block a user