Skip to content

Commit 8cd1c3b

Browse files
committed
Custom resolver option
1 parent 487206f commit 8cd1c3b

File tree

5 files changed

+40
-11
lines changed

5 files changed

+40
-11
lines changed

docs/source/driver.rst

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,24 @@ The maximum time to allow for retries to be attempted when using transaction fun
144144
After this time, no more retries will be attempted.
145145
This setting does not terminate running queries.
146146

147+
``resolver``
148+
------------
149+
150+
A custom resolver function to use for DNS resolution.
151+
This function is called with a 2-tuple of (host, port) and should return an iterable of tuples as would be returned from ``getaddrinfo``.
152+
153+
For example::
154+
155+
def my_resolver(socket_address):
156+
if socket_address == ("foo", 9999):
157+
yield "::1", 7687
158+
yield "127.0.0.1", 7687
159+
else:
160+
from socket import gaierror
161+
raise gaierror("Unexpected socket address %r" % socket_address)
162+
163+
driver = GraphDatabase.driver("bolt://foo:9999", auth=("neo4j", "password"), resolver=my_resolver)
164+
147165

148166

149167
Object Lifetime

neo4j/addressing.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,9 @@
2121
from collections import namedtuple
2222
from socket import getaddrinfo, gaierror, SOCK_STREAM, IPPROTO_TCP
2323

24-
from neo4j.compat import urlparse
24+
from neo4j.compat import urlparse, parse_qs
2525
from neo4j.exceptions import AddressError
2626

27-
try:
28-
from urllib.parse import parse_qs
29-
except ImportError:
30-
from urlparse import parse_qs
31-
3227

3328
VALID_IPv4_SEGMENTS = [str(i).encode("latin1") for i in range(0x100)]
3429
VALID_IPv6_SEGMENT_CHARS = b"0123456789abcdef"

neo4j/bolt/connection.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -685,12 +685,15 @@ def connect(address, ssl_context=None, error_handler=None, **config):
685685
a protocol version can be agreed.
686686
"""
687687

688+
last_error = None
688689
# Establish a connection to the host and port specified
689690
# Catches refused connections see:
690691
# https://docs.python.org/2/library/errno.html
691692
log_debug("~~ [RESOLVE] %s", address)
692-
last_error = None
693-
for resolved_address in resolve(address):
693+
resolver = config.get("resolver")
694+
if not callable(resolver):
695+
resolver = resolve
696+
for resolved_address in resolver(address):
694697
log_debug("~~ [RESOLVED] %s -> %s", address, resolved_address)
695698
try:
696699
s = _connect(resolved_address, **config)

neo4j/compat/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,6 @@ def perf_counter():
122122

123123
# The location of urlparse varies between Python 2 and 3
124124
try:
125-
from urllib.parse import urlparse
125+
from urllib.parse import urlparse, parse_qs
126126
except ImportError:
127-
from urlparse import urlparse
127+
from urlparse import urlparse, parse_qs

test/integration/test_driver.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
# limitations under the License.
1919

2020

21-
from neo4j.v1 import GraphDatabase, ServiceUnavailable
21+
from neo4j.bolt import DEFAULT_PORT
22+
from neo4j.v1 import GraphDatabase, Driver, ServiceUnavailable
2223
from test.integration.tools import IntegrationTestCase
2324

2425

@@ -43,3 +44,15 @@ def test_fail_nicely_when_using_http_port(self):
4344
with self.assertRaises(ServiceUnavailable):
4445
with GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False):
4546
pass
47+
48+
def test_custom_resolver(self):
49+
50+
def my_resolver(socket_address):
51+
self.assertEqual(socket_address, ("*", DEFAULT_PORT))
52+
yield "99.99.99.99", self.bolt_port # this should be rejected as unable to connect
53+
yield "127.0.0.1", self.bolt_port # this should succeed
54+
55+
with Driver("bolt://*", auth=self.auth_token, resolver=my_resolver) as driver:
56+
with driver.session() as session:
57+
summary = session.run("RETURN 1").summary()
58+
self.assertEqual(summary.server.address, ("127.0.0.1", 7687))

0 commit comments

Comments
 (0)