Use getaddrinfo to create the right socket

This commit is contained in:
nemunaire 2016-08-10 23:56:50 +02:00
parent 205a39ad70
commit 5f89428562
3 changed files with 23 additions and 24 deletions

View File

@ -16,6 +16,7 @@
from datetime import datetime from datetime import datetime
import re import re
import socket
from nemubot.channel import Channel from nemubot.channel import Channel
from nemubot.message.printer.IRC import IRC as IRCPrinter from nemubot.message.printer.IRC import IRC as IRCPrinter
@ -240,7 +241,7 @@ class _IRC:
if self.capabilities is not None: if self.capabilities is not None:
self.write("CAP LS") self.write("CAP LS")
self.write("NICK :" + self.nick) self.write("NICK :" + self.nick)
self.write("USER %s %s bla :%s" % (self.username, self.host, self.realname)) self.write("USER %s %s bla :%s" % (self.username, socket.getfqdn(), self.realname))
def close(self): def close(self):

View File

@ -14,6 +14,7 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import socket
import unittest import unittest
from nemubot.server import factory from nemubot.server import factory
@ -27,25 +28,31 @@ class TestFactory(unittest.TestCase):
# <host>: If omitted, the client must connect to a prespecified default IRC server. # <host>: If omitted, the client must connect to a prespecified default IRC server.
server = factory("irc:///") server = factory("irc:///")
self.assertIsInstance(server, IRCServer) self.assertIsInstance(server, IRCServer)
self.assertEqual(server.host, "localhost") self.assertEqual(server._sockaddr,
socket.getaddrinfo("localhost", 6667, proto=socket.IPPROTO_TCP)[0][4])
server = factory("ircs:///") server = factory("ircs:///")
self.assertIsInstance(server, IRCSServer) self.assertIsInstance(server, IRCSServer)
self.assertEqual(server.host, "localhost") self.assertEqual(server._sockaddr,
socket.getaddrinfo("localhost", 6667, proto=socket.IPPROTO_TCP)[0][4])
server = factory("irc://host1") server = factory("irc://freenode.net")
self.assertIsInstance(server, IRCServer) self.assertIsInstance(server, IRCServer)
self.assertEqual(server.host, "host1") self.assertEqual(server._sockaddr,
socket.getaddrinfo("freenode.net", 6667, proto=socket.IPPROTO_TCP)[0][4])
server = factory("irc://host2:6667") server = factory("irc://freenode.org:1234")
self.assertIsInstance(server, IRCServer) self.assertIsInstance(server, IRCServer)
self.assertEqual(server.host, "host2") self.assertEqual(server._sockaddr,
self.assertEqual(server.port, 6667) socket.getaddrinfo("freenode.org", 1234, proto=socket.IPPROTO_TCP)[0][4])
server = factory("ircs://host3:194/") server = factory("ircs://nemunai.re:194/")
self.assertIsInstance(server, IRCSServer) self.assertIsInstance(server, IRCSServer)
self.assertEqual(server.host, "host3") self.assertEqual(server._sockaddr,
self.assertEqual(server.port, 194) socket.getaddrinfo("nemunai.re", 194, proto=socket.IPPROTO_TCP)[0][4])
with self.assertRaises(socket.gaierror):
factory("irc://_nonexistent.nemunai.re")
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -81,29 +81,20 @@ class _Socket(AbstractServer):
class _SocketServer(_Socket): class _SocketServer(_Socket):
def __init__(self, host, port, bind=None, **kwargs): def __init__(self, host, port, bind=None, **kwargs):
super().__init__(family=socket.AF_INET, **kwargs) (family, type, proto, canonname, sockaddr) = socket.getaddrinfo(host, port)[0]
assert(host is not None)
assert(isinstance(port, int))
if isinstance(self, ssl.SSLSocket) and "server_hostname" not in kwargs: if isinstance(self, ssl.SSLSocket) and "server_hostname" not in kwargs:
kwargs["server_hostname"] = host kwargs["server_hostname"] = host
super().__init__(family=family, type=type, proto=proto, **kwargs) super().__init__(family=family, type=type, proto=proto, **kwargs)
self._host = host self._sockaddr = sockaddr
self._port = port
self._bind = bind self._bind = bind
@property
def host(self):
return self._host
def connect(self): def connect(self):
self.logger.info("Connection to %s:%d", self._host, self._port) self.logger.info("Connection to %s:%d", *self._sockaddr[:2])
super().connect((self._host, self._port)) super().connect(self._sockaddr)
if self._bind: if self._bind:
super().bind(self._bind) super().bind(self._bind)