Use getaddrinfo to create the right socket

This commit is contained in:
nemunaire 2016-08-10 23:56:50 +02:00
parent 205a39ad70
commit 5f89428562
3 changed files with 23 additions and 24 deletions

View File

@ -16,6 +16,7 @@
from datetime import datetime
import re
import socket
from nemubot.channel import Channel
from nemubot.message.printer.IRC import IRC as IRCPrinter
@ -240,7 +241,7 @@ class _IRC:
if self.capabilities is not None:
self.write("CAP LS")
self.write("NICK :" + self.nick)
self.write("USER %s %s bla :%s" % (self.username, self.host, self.realname))
self.write("USER %s %s bla :%s" % (self.username, socket.getfqdn(), self.realname))
def close(self):

View File

@ -14,6 +14,7 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import socket
import unittest
from nemubot.server import factory
@ -27,25 +28,31 @@ class TestFactory(unittest.TestCase):
# <host>: If omitted, the client must connect to a prespecified default IRC server.
server = factory("irc:///")
self.assertIsInstance(server, IRCServer)
self.assertEqual(server.host, "localhost")
self.assertEqual(server._sockaddr,
socket.getaddrinfo("localhost", 6667, proto=socket.IPPROTO_TCP)[0][4])
server = factory("ircs:///")
self.assertIsInstance(server, IRCSServer)
self.assertEqual(server.host, "localhost")
self.assertEqual(server._sockaddr,
socket.getaddrinfo("localhost", 6667, proto=socket.IPPROTO_TCP)[0][4])
server = factory("irc://host1")
server = factory("irc://freenode.net")
self.assertIsInstance(server, IRCServer)
self.assertEqual(server.host, "host1")
self.assertEqual(server._sockaddr,
socket.getaddrinfo("freenode.net", 6667, proto=socket.IPPROTO_TCP)[0][4])
server = factory("irc://host2:6667")
server = factory("irc://freenode.org:1234")
self.assertIsInstance(server, IRCServer)
self.assertEqual(server.host, "host2")
self.assertEqual(server.port, 6667)
self.assertEqual(server._sockaddr,
socket.getaddrinfo("freenode.org", 1234, proto=socket.IPPROTO_TCP)[0][4])
server = factory("ircs://host3:194/")
server = factory("ircs://nemunai.re:194/")
self.assertIsInstance(server, IRCSServer)
self.assertEqual(server.host, "host3")
self.assertEqual(server.port, 194)
self.assertEqual(server._sockaddr,
socket.getaddrinfo("nemunai.re", 194, proto=socket.IPPROTO_TCP)[0][4])
with self.assertRaises(socket.gaierror):
factory("irc://_nonexistent.nemunai.re")
if __name__ == '__main__':

View File

@ -81,29 +81,20 @@ class _Socket(AbstractServer):
class _SocketServer(_Socket):
def __init__(self, host, port, bind=None, **kwargs):
super().__init__(family=socket.AF_INET, **kwargs)
assert(host is not None)
assert(isinstance(port, int))
(family, type, proto, canonname, sockaddr) = socket.getaddrinfo(host, port)[0]
if isinstance(self, ssl.SSLSocket) and "server_hostname" not in kwargs:
kwargs["server_hostname"] = host
super().__init__(family=family, type=type, proto=proto, **kwargs)
self._host = host
self._port = port
self._sockaddr = sockaddr
self._bind = bind
@property
def host(self):
return self._host
def connect(self):
self.logger.info("Connection to %s:%d", self._host, self._port)
super().connect((self._host, self._port))
self.logger.info("Connection to %s:%d", *self._sockaddr[:2])
super().connect(self._sockaddr)
if self._bind:
super().bind(self._bind)