From 861ca0afddc625da6e19606c86f0819b80965ab1 Mon Sep 17 00:00:00 2001 From: Pierre-Olivier Mercier Date: Tue, 17 Jan 2023 21:55:25 +0100 Subject: [PATCH] Try to connect multiple times (with different servers if any) --- nemubot/__main__.py | 16 +++++++++++----- nemubot/config/server.py | 4 ++-- nemubot/server/socket.py | 5 +++-- 3 files changed, 16 insertions(+), 9 deletions(-) diff --git a/nemubot/__main__.py b/nemubot/__main__.py index 4275d95..6a8b265 100644 --- a/nemubot/__main__.py +++ b/nemubot/__main__.py @@ -155,12 +155,18 @@ def main(): # Preset each server in this file for server in config.servers: - srv = server.server(config) # Add the server in the context - if context.add_server(srv): - logger.info("Server '%s' successfully added.", srv.name) - else: - logger.error("Can't add server '%s'.", srv.name) + for i in [0,1,2,3]: + srv = server.server(config, trynb=i) + try: + if context.add_server(srv): + logger.info("Server '%s' successfully added.", srv.name) + else: + logger.error("Can't add server '%s'.", srv.name) + except: + logger.error("Unable to connect to '%s'.", srv.name) + continue + break # Load module and their configuration for mod in config.modules: diff --git a/nemubot/config/server.py b/nemubot/config/server.py index 14ca9a8..17bfaee 100644 --- a/nemubot/config/server.py +++ b/nemubot/config/server.py @@ -33,7 +33,7 @@ class Server: return True - def server(self, parent): + def server(self, parent, trynb=0): from nemubot.server import factory for a in ["nick", "owner", "realname", "encoding"]: @@ -42,4 +42,4 @@ class Server: self.caps += parent.caps - return factory(self.uri, caps=self.caps, channels=self.channels, **self.args) + return factory(self.uri, caps=self.caps, channels=self.channels, trynb=trynb, **self.args) diff --git a/nemubot/server/socket.py b/nemubot/server/socket.py index a6be620..bf55bf5 100644 --- a/nemubot/server/socket.py +++ b/nemubot/server/socket.py @@ -79,8 +79,9 @@ class _Socket(AbstractServer): class SocketServer(_Socket): - def __init__(self, host, port, bind=None, **kwargs): - (family, type, proto, canonname, self._sockaddr) = socket.getaddrinfo(host, port, proto=socket.IPPROTO_TCP)[0] + def __init__(self, host, port, bind=None, trynb=0, **kwargs): + destlist = socket.getaddrinfo(host, port, proto=socket.IPPROTO_TCP) + (family, type, proto, canonname, self._sockaddr) = destlist[trynb%len(destlist)] super().__init__(fdClass=socket.socket, family=family, type=type, proto=proto, **kwargs)