Skip to content

Commit 949014c

Browse files
committed
Chained resolution
1 parent 1565676 commit 949014c

File tree

2 files changed

+49
-19
lines changed

2 files changed

+49
-19
lines changed

neo4j/addressing.py

Lines changed: 43 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -97,17 +97,46 @@ def parse_routing_context(cls, uri):
9797
return context
9898

9999

100-
def resolve(socket_address):
101-
try:
102-
info = getaddrinfo(socket_address[0], socket_address[1], 0, SOCK_STREAM, IPPROTO_TCP)
103-
except gaierror:
104-
raise AddressError("Cannot resolve address {!r}".format(socket_address[0]))
105-
else:
106-
addresses = []
107-
for _, _, _, _, address in info:
108-
if len(address) == 4 and address[3] != 0:
109-
# skip any IPv6 addresses with a non-zero scope id
110-
# as these appear to cause problems on some platforms
111-
continue
112-
addresses.append(address)
113-
return addresses
100+
class Resolver(object):
101+
""" A Resolver instance stores a list of addresses, each in a tuple, and
102+
provides methods to perform resolution on these, thereby replacing them
103+
with the resolved values.
104+
"""
105+
106+
def __init__(self, custom_resolver=None):
107+
self.addresses = []
108+
self.custom_resolver = custom_resolver
109+
110+
def custom_resolve(self):
111+
""" If a custom resolver is defined, perform custom resolution on
112+
the contained addresses.
113+
114+
:return:
115+
"""
116+
if not callable(self.custom_resolver):
117+
return
118+
new_addresses = []
119+
for address in self.addresses:
120+
for new_address in self.custom_resolver(address):
121+
new_addresses.append(new_address)
122+
self.addresses = new_addresses
123+
124+
def dns_resolve(self):
125+
""" Perform DNS resolution on the contained addresses.
126+
127+
:return:
128+
"""
129+
new_addresses = []
130+
for address in self.addresses:
131+
try:
132+
info = getaddrinfo(address[0], address[1], 0, SOCK_STREAM, IPPROTO_TCP)
133+
except gaierror:
134+
raise AddressError("Cannot resolve address {!r}".format(address))
135+
else:
136+
for _, _, _, _, address in info:
137+
if len(address) == 4 and address[3] != 0:
138+
# skip any IPv6 addresses with a non-zero scope id
139+
# as these appear to cause problems on some platforms
140+
continue
141+
new_addresses.append(address)
142+
self.addresses = new_addresses

neo4j/bolt/connection.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from struct import pack as struct_pack, unpack as struct_unpack
3434
from threading import RLock, Condition
3535

36-
from neo4j.addressing import SocketAddress, resolve
36+
from neo4j.addressing import SocketAddress, Resolver
3737
from neo4j.bolt.cert import KNOWN_HOSTS
3838
from neo4j.bolt.response import InitResponse, AckFailureResponse, ResetResponse
3939
from neo4j.compat.ssl import SSL_AVAILABLE, HAS_SNI, SSLError
@@ -690,10 +690,11 @@ def connect(address, ssl_context=None, error_handler=None, **config):
690690
# Catches refused connections see:
691691
# https://docs.python.org/2/library/errno.html
692692
log_debug("~~ [RESOLVE] %s", address)
693-
resolver = config.get("resolver")
694-
if not callable(resolver):
695-
resolver = resolve
696-
for resolved_address in resolver(address):
693+
resolver = Resolver(custom_resolver=config.get("resolver"))
694+
resolver.addresses.append(address)
695+
resolver.custom_resolve()
696+
resolver.dns_resolve()
697+
for resolved_address in resolver.addresses:
697698
log_debug("~~ [RESOLVED] %s -> %s", address, resolved_address)
698699
try:
699700
s = _connect(resolved_address, **config)

0 commit comments

Comments
 (0)