Skip to content

Commit dcce02e

Browse files
committed
feat: Use configured DNS name to lookup instance IP address
When a custom DNS name is used to connect to a Cloud SQL instance, the dialer should first attempt to resolve the custom DNS name to an IP address and use that for the connection. If the lookup fails, the dialer should fall back to using the IP address from the instance metadata. Fixes #1362
1 parent ae7db85 commit dcce02e

File tree

5 files changed

+192
-0
lines changed

5 files changed

+192
-0
lines changed

build.sh

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ if [[ ! -d venv ]] ; then
2828
echo "./venv not found. Setting up venv"
2929
python3 -m venv "$PWD/venv"
3030
fi
31+
3132
source "$PWD/venv/bin/activate"
3233

3334
if which pip3 ; then
@@ -135,6 +136,10 @@ function write_e2e_env(){
135136
val=$(gcloud secrets versions access latest --project "$TEST_PROJECT" --secret="$secret_name")
136137
echo "export $env_var_name='$val'"
137138
done
139+
# Aliases for python e2e tests
140+
echo "export POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME=\"\$POSTGRES_CUSTOMER_CAS_DOMAIN_NAME\""
141+
echo "export POSTGRES_IAM_USER=\"\$POSTGRES_USER_IAM\""
142+
echo "export MYSQL_IAM_USER=\"\$MYSQL_USER_IAM\""
138143
} > "$1"
139144

140145
}

google/cloud/sql/connector/connector.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,33 @@ async def connect_async(
390390
# the cache and re-raise the error
391391
await self._remove_cached(str(conn_name), enable_iam_auth)
392392
raise
393+
394+
# If the connector is configured with a custom DNS name, attempt to use
395+
# that DNS name to connect to the instance. Fall back to the metadata IP
396+
# address if the DNS name does not resolve to an IP address.
397+
if conn_info.conn_name.domain_name and isinstance(self._resolver, DnsResolver):
398+
try:
399+
ips = await self._resolver.resolve_a_record(conn_info.conn_name.domain_name)
400+
if ips:
401+
ip_address = ips[0]
402+
logger.debug(
403+
f"['{instance_connection_string}']: Custom DNS name "
404+
f"'{conn_info.conn_name.domain_name}' resolved to '{ip_address}', "
405+
"using it to connect"
406+
)
407+
else:
408+
logger.debug(
409+
f"['{instance_connection_string}']: Custom DNS name "
410+
f"'{conn_info.conn_name.domain_name}' resolved but returned no "
411+
f"entries, using '{ip_address}' from instance metadata"
412+
)
413+
except Exception as e:
414+
logger.debug(
415+
f"['{instance_connection_string}']: Custom DNS name "
416+
f"'{conn_info.conn_name.domain_name}' did not resolve to an IP "
417+
f"address: {e}, using '{ip_address}' from instance metadata"
418+
)
419+
393420
logger.debug(f"['{conn_info.conn_name}']: Connecting to {ip_address}:3307")
394421
# format `user` param for automatic IAM database authn
395422
if enable_iam_auth:

google/cloud/sql/connector/resolver.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from typing import List
16+
1517
import dns.asyncresolver
1618

1719
from google.cloud.sql.connector.connection_name import _is_valid_domain
@@ -53,6 +55,16 @@ async def resolve(self, dns: str) -> ConnectionName: # type: ignore
5355
)
5456
return conn_name
5557

58+
async def resolve_a_record(self, dns: str) -> List[str]:
59+
try:
60+
# Attempt to query the A records.
61+
records = await super().resolve(dns, "A", raise_on_no_answer=True)
62+
# return IP addresses as strings
63+
return [record.to_text() for record in records]
64+
except Exception:
65+
# On any error, return empty list
66+
return []
67+
5668
async def query_dns(self, dns: str) -> ConnectionName:
5769
try:
5870
# Attempt to query the TXT records.

tests/unit/test_connector.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from google.cloud.sql.connector.exceptions import ConnectorLoopError
3535
from google.cloud.sql.connector.exceptions import IncompatibleDriverError
3636
from google.cloud.sql.connector.instance import RefreshAheadCache
37+
from google.cloud.sql.connector.resolver import DnsResolver
3738

3839

3940
@pytest.mark.asyncio
@@ -548,3 +549,113 @@ def test_connect_closed_connector(
548549
exc_info.value.args[0]
549550
== "Connection attempt failed because the connector has already been closed."
550551
)
552+
553+
554+
@pytest.mark.asyncio
555+
async def test_Connector_connect_async_custom_dns_resolver(
556+
fake_credentials: Credentials, fake_client: CloudSQLClient
557+
) -> None:
558+
"""Test that Connector.connect_async uses custom DNS name resolution."""
559+
560+
# Create a mock DnsResolver that returns a fixed IP
561+
with patch(
562+
"google.cloud.sql.connector.resolver.DnsResolver.resolve_a_record"
563+
) as mock_resolve_a:
564+
mock_resolve_a.return_value = ["1.2.3.4"]
565+
566+
# We also need to patch resolve because DnsResolver.resolve does DNS lookup for TXT
567+
# But we can patch DnsResolver.resolve to return a ConnectionName with domain name
568+
with patch(
569+
"google.cloud.sql.connector.resolver.DnsResolver.resolve"
570+
) as mock_resolve:
571+
# This must return a ConnectionName object with domain_name set
572+
conn_name_with_domain = ConnectionName(
573+
"test-project", "test-region", "test-instance", "db.example.com"
574+
)
575+
mock_resolve.return_value = conn_name_with_domain
576+
577+
async with Connector(
578+
credentials=fake_credentials,
579+
loop=asyncio.get_running_loop(),
580+
resolver=DnsResolver,
581+
) as connector:
582+
connector._client = fake_client
583+
584+
# patch db connection creation
585+
with patch(
586+
"google.cloud.sql.connector.asyncpg.connect"
587+
) as mock_connect:
588+
mock_connect.return_value = True
589+
590+
# Call connect_async
591+
# Use "db.example.com" as instance connection string (resolver will handle it)
592+
connection = await connector.connect_async(
593+
"db.example.com",
594+
"asyncpg",
595+
user="my-user",
596+
password="my-pass",
597+
db="my-db",
598+
)
599+
600+
# Verify mock_connect was called with resolved IP "1.2.3.4"
601+
# The first arg to mock_connect (which patches connector call) is ip_address
602+
args, _ = mock_connect.call_args
603+
assert args[0] == "1.2.3.4"
604+
assert connection is True
605+
606+
607+
@pytest.mark.asyncio
608+
async def test_Connector_connect_async_custom_dns_resolver_fallback(
609+
fake_credentials: Credentials, fake_client: CloudSQLClient
610+
) -> None:
611+
"""Test that Connector.connect_async falls back if DNS resolution fails."""
612+
613+
# Create a mock DnsResolver that returns empty list (failure)
614+
with patch(
615+
"google.cloud.sql.connector.resolver.DnsResolver.resolve_a_record"
616+
) as mock_resolve_a:
617+
mock_resolve_a.return_value = []
618+
619+
with patch(
620+
"google.cloud.sql.connector.resolver.DnsResolver.resolve"
621+
) as mock_resolve:
622+
conn_name_with_domain = ConnectionName(
623+
"test-project", "test-region", "test-instance", "db.example.com"
624+
)
625+
mock_resolve.return_value = conn_name_with_domain
626+
627+
async with Connector(
628+
credentials=fake_credentials,
629+
loop=asyncio.get_running_loop(),
630+
resolver=DnsResolver,
631+
) as connector:
632+
connector._client = fake_client
633+
634+
# Save original IPs to restore later (fake_instance is session-scoped)
635+
original_ips = fake_client.instance.ip_addrs
636+
# Set metadata IP to something specific
637+
fake_client.instance.ip_addrs = {"PRIMARY": "5.6.7.8"}
638+
639+
try:
640+
with patch(
641+
"google.cloud.sql.connector.asyncpg.connect"
642+
) as mock_connect:
643+
mock_connect.return_value = True
644+
645+
connection = await connector.connect_async(
646+
"db.example.com",
647+
"asyncpg",
648+
user="my-user",
649+
password="my-pass",
650+
db="my-db",
651+
)
652+
653+
# Verify mock_connect was called with metadata IP "5.6.7.8"
654+
args, _ = mock_connect.call_args
655+
assert args[0] == "5.6.7.8"
656+
assert connection is True
657+
finally:
658+
# Restore original IPs
659+
fake_client.instance.ip_addrs = original_ips
660+
661+

tests/unit/test_resolver.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,3 +129,40 @@ async def test_DnsResolver_with_bad_dns_name() -> None:
129129
with pytest.raises(DnsResolutionError) as exc_info:
130130
await resolver.resolve("bad.dns.com")
131131
assert exc_info.value.args[0] == "Unable to resolve TXT record for `bad.dns.com`"
132+
133+
134+
a_record_query_text = """id 1234
135+
opcode QUERY
136+
rcode NOERROR
137+
flags QR AA RD RA
138+
;QUESTION
139+
db.example.com. IN A
140+
;ANSWER
141+
db.example.com. 0 IN A 127.0.0.1
142+
;AUTHORITY
143+
;ADDITIONAL
144+
"""
145+
146+
147+
async def test_DnsResolver_resolve_a_record() -> None:
148+
"""Test DnsResolver resolves A record into IP address."""
149+
with patch("dns.asyncresolver.Resolver.resolve") as mock_resolve:
150+
answer = dns.resolver.Answer(
151+
"db.example.com",
152+
dns.rdatatype.A,
153+
dns.rdataclass.IN,
154+
dns.message.from_text(a_record_query_text),
155+
)
156+
mock_resolve.return_value = answer
157+
resolver = DnsResolver()
158+
result = await resolver.resolve_a_record("db.example.com")
159+
assert result == ["127.0.0.1"]
160+
161+
162+
async def test_DnsResolver_resolve_a_record_empty() -> None:
163+
"""Test DnsResolver resolves A record but gets error."""
164+
with patch("dns.asyncresolver.Resolver.resolve") as mock_resolve:
165+
mock_resolve.side_effect = Exception("DNS Error")
166+
resolver = DnsResolver()
167+
result = await resolver.resolve_a_record("db.example.com")
168+
assert result == []

0 commit comments

Comments
 (0)