Skip to content

Commit 0705d9c

Browse files
kbandesKenneth Bandesparthea
authored
fix: ensure rest unit tests have complete coverage (googleapis#1098)
* fix: rest paging and lro client tests weren't working. * fix: fix coverage gaps in rest unit tests. * fix: refactor required fields code to move update out of static method. * fix: test that api method with required fields handles them correctly. * fix: removed extra parens from an expression in a test. Co-authored-by: Kenneth Bandes <kbandes@google.com> Co-authored-by: Anthonios Partheniou <partheniou@google.com>
1 parent 956078f commit 0705d9c

File tree

8 files changed

+174
-39
lines changed

8 files changed

+174
-39
lines changed

gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2

Lines changed: 18 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,22 @@ class {{service.name}}RestTransport({{service.name}}Transport):
166166
{% endif %}{# service.has_lro #}
167167
{% for method in service.methods.values() %}
168168
{%- if method.http_options and not (method.server_streaming or method.client_streaming) %}
169+
170+
{% if method.input.required_fields %}
171+
__{{ method.name | snake_case }}_required_fields_default_values = {
172+
{% for req_field in method.input.required_fields if req_field.is_primitive %}
173+
"{{ req_field.name | camel_case }}" : {% if req_field.field_pb.default_value is string %}"{{req_field.field_pb.default_value }}"{% else %}{{ req_field.field_pb.default_value }}{% endif %}{# default is str #}
174+
{% endfor %}
175+
}
176+
177+
178+
@staticmethod
179+
def _{{ method.name | snake_case }}_get_unset_required_fields(message_dict):
180+
return {k: v for k, v in {{service.name}}RestTransport.__{{ method.name | snake_case }}_required_fields_default_values.items() if k not in message_dict}
181+
182+
183+
{% endif %}{# required fields #}
184+
169185
def _{{method.name | snake_case}}(self,
170186
request: {{method.input.ident}}, *,
171187
retry: OptionalRetry=gapic_v1.method.DEFAULT,
@@ -206,21 +222,6 @@ class {{service.name}}RestTransport({{service.name}}Transport):
206222
{% endfor %}
207223
]
208224

209-
{% if method.input.required_fields %}
210-
required_fields = [
211-
# (snake_case_name, camel_case_name)
212-
{% for req_field in method.input.required_fields %}
213-
{% if req_field.is_primitive %}
214-
(
215-
"{{ req_field.name | snake_case }}",
216-
"{{ req_field.name | camel_case }}"
217-
),
218-
{% endif %}{# is primitive #}
219-
{% endfor %}{# required fields #}
220-
]
221-
222-
{% endif %}
223-
224225
request_kwargs = {{method.input.ident}}.to_dict(request)
225226
transcoded_request = path_template.transcode(
226227
http_options, **request_kwargs)
@@ -254,16 +255,8 @@ class {{service.name}}RestTransport({{service.name}}Transport):
254255
))
255256

256257
{% if method.input.required_fields %}
257-
# Ensure required fields have values in query_params.
258-
# If a required field has a default value, it can get lost
259-
# by the to_json call above.
260-
orig_query_params = transcoded_request["query_params"]
261-
for snake_case_name, camel_case_name in required_fields:
262-
if snake_case_name in orig_query_params:
263-
if camel_case_name not in query_params:
264-
query_params[camel_case_name] = orig_query_params[snake_case_name]
265-
266-
{% endif %}
258+
query_params.update(self._{{ method.name | snake_case }}_get_unset_required_fields(query_params))
259+
{% endif %}{# required fields #}
267260

268261
# Send the request
269262
headers = dict(metadata)

gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2

Lines changed: 150 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import mock
77

88
import grpc
99
from grpc.experimental import aio
10+
import json
1011
import math
1112
import pytest
1213
from proto.marshal.rules.dates import DurationRule, TimestampRule
@@ -1187,6 +1188,7 @@ def test_{{ method_name }}_rest(transport: str = 'rest', request_type={{ method.
11871188
{% if "next_page_token" in method.output.fields.values()|map(attribute='name') and not method.paged_result_field %}
11881189
{# Cheeser assertion to force code coverage for bad paginated methods #}
11891190
assert response.raw_page is response
1191+
11901192
{% endif %}
11911193

11921194
# Establish that the response is the type that we expect.
@@ -1210,6 +1212,130 @@ def test_{{ method_name }}_rest(transport: str = 'rest', request_type={{ method.
12101212
{% endif %}
12111213

12121214

1215+
{% if method.input.required_fields %}
1216+
def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ident }}):
1217+
transport_class = transports.{{ service.rest_transport_name }}
1218+
1219+
request_init = {}
1220+
{% for req_field in method.input.required_fields if req_field.is_primitive %}
1221+
{% if req_field.field_pb.default_value is string %}
1222+
request_init["{{ req_field.name }}"] = "{{ req_field.field_pb.default_value }}"
1223+
{% else %}
1224+
request_init["{{ req_field.name }}"] = {{ req_field.field_pb.default_value }}
1225+
{% endif %}{# default is str #}
1226+
{% endfor %}
1227+
request = request_type(request_init)
1228+
jsonified_request = json.loads(request_type.to_json(
1229+
request,
1230+
including_default_value_fields=False,
1231+
use_integers_for_enums=False
1232+
))
1233+
1234+
# verify fields with default values are dropped
1235+
{% for req_field in method.input.required_fields if req_field.is_primitive %}
1236+
{% set field_name = req_field.name | camel_case %}
1237+
assert "{{ field_name }}" not in jsonified_request
1238+
{% endfor %}
1239+
1240+
unset_fields = transport_class._{{ method.name | snake_case }}_get_unset_required_fields(jsonified_request)
1241+
jsonified_request.update(unset_fields)
1242+
1243+
# verify required fields with default values are now present
1244+
{% for req_field in method.input.required_fields if req_field.is_primitive %}
1245+
{% set field_name = req_field.name | camel_case %}
1246+
assert "{{ field_name }}" in jsonified_request
1247+
assert jsonified_request["{{ field_name }}"] == request_init["{{ req_field.name }}"]
1248+
{% endfor %}
1249+
1250+
{% for req_field in method.input.required_fields if req_field.is_primitive %}
1251+
{% set field_name = req_field.name | camel_case %}
1252+
{% set mock_value = req_field.primitive_mock_as_str() %}
1253+
jsonified_request["{{ field_name }}"] = {{ mock_value }}
1254+
{% endfor %}
1255+
1256+
unset_fields = transport_class._{{ method.name | snake_case }}_get_unset_required_fields(jsonified_request)
1257+
jsonified_request.update(unset_fields)
1258+
1259+
# verify required fields with non-default values are left alone
1260+
{% for req_field in method.input.required_fields if req_field.is_primitive %}
1261+
{% set field_name = req_field.name | camel_case %}
1262+
{% set mock_value = req_field.primitive_mock_as_str() %}
1263+
assert "{{ field_name }}" in jsonified_request
1264+
assert jsonified_request["{{ field_name }}"] == {{ mock_value }}
1265+
{% endfor %}
1266+
1267+
1268+
client = {{ service.client_name }}(
1269+
credentials=ga_credentials.AnonymousCredentials(),
1270+
transport='rest',
1271+
)
1272+
request = request_type(request_init)
1273+
1274+
# Designate an appropriate value for the returned response.
1275+
{% if method.void %}
1276+
return_value = None
1277+
{% elif method.lro %}
1278+
return_value = operations_pb2.Operation(name='operations/spam')
1279+
{% elif method.server_streaming %}
1280+
return_value = iter([{{ method.output.ident }}()])
1281+
{% else %}
1282+
return_value = {{ method.output.ident }}()
1283+
{% endif %}
1284+
# Mock the http request call within the method and fake a response.
1285+
with mock.patch.object(Session, 'request') as req:
1286+
# We need to mock transcode() because providing default values
1287+
# for required fields will fail the real version if the http_options
1288+
# expect actual values for those fields.
1289+
with mock.patch.object(path_template, 'transcode') as transcode:
1290+
# A uri without fields and an empty body will force all the
1291+
# request fields to show up in the query_params.
1292+
transcode_result = {
1293+
'uri': 'v1/sample_method',
1294+
'method': "{{ method.http_options[0].method }}",
1295+
'query_params': request_init,
1296+
}
1297+
{% if method.http_options[0].body %}
1298+
transcode_result['body'] = {}
1299+
{% endif %}
1300+
transcode.return_value = transcode_result
1301+
1302+
response_value = Response()
1303+
response_value.status_code = 200
1304+
{% if method.void %}
1305+
json_return_value = ''
1306+
{% elif method.lro %}
1307+
json_return_value = json_format.MessageToJson(return_value)
1308+
{% else %}
1309+
json_return_value = {{ method.output.ident }}.to_json(return_value)
1310+
{% endif %}
1311+
response_value._content = json_return_value.encode('UTF-8')
1312+
req.return_value = response_value
1313+
1314+
{% if method.client_streaming %}
1315+
response = client.{{ method.name|snake_case }}(iter(requests))
1316+
{% else %}
1317+
response = client.{{ method_name }}(request)
1318+
{% endif %}
1319+
1320+
expected_params = [
1321+
{% for req_field in method.input.required_fields if req_field.is_primitive %}
1322+
(
1323+
"{{ req_field.name }}",
1324+
{% if req_field.field_pb.default_value is string %}
1325+
"{{ req_field.field_pb.default_value }}"
1326+
{% else %}
1327+
{{ req_field.field_pb.default_value }}
1328+
{% endif %}{# default is str #}
1329+
)
1330+
{% endfor %}
1331+
]
1332+
actual_params = req.call_args.kwargs['params']
1333+
assert expected_params == actual_params
1334+
1335+
1336+
{% endif %}{# required_fields #}
1337+
1338+
12131339
def test_{{ method_name }}_rest_bad_request(transport: str = 'rest', request_type={{ method.input.ident }}):
12141340
client = {{ service.client_name }}(
12151341
credentials=ga_credentials.AnonymousCredentials(),
@@ -1325,9 +1451,10 @@ def test_{{ method_name }}_rest_flattened_error(transport: str = 'rest'):
13251451

13261452

13271453
{% if method.paged_result_field %}
1328-
def test_{{ method_name }}_rest_pager():
1454+
def test_{{ method_name }}_rest_pager(transport: str = 'rest'):
13291455
client = {{ service.client_name }}(
13301456
credentials=ga_credentials.AnonymousCredentials(),
1457+
transport=transport,
13311458
)
13321459

13331460
# Mock the http request call within the method and fake a response.
@@ -1446,25 +1573,35 @@ def test_{{ method_name }}_rest_error():
14461573
credentials=ga_credentials.AnonymousCredentials(),
14471574
transport='rest'
14481575
)
1449-
{%- if not method.http_options %}
1450-
# Since a `google.api.http` annotation is required for using a rest transport
1451-
# method, this should error.
1452-
with pytest.raises(RuntimeError) as runtime_error:
1453-
client.{{ method_name }}({})
1454-
assert ('Cannot define a method without a valid `google.api.http` annotation.'
1455-
in str(runtime_error.value))
1456-
{%- else %}
14571576

14581577
# TODO(yon-mg): Remove when this method has a working implementation
14591578
# or testing straegy
14601579
with pytest.raises(NotImplementedError):
14611580
client.{{ method_name }}({})
14621581

1463-
{%- endif %}
14641582

1465-
{% endif %}{% endwith %}{# method_name #}
1583+
{% endif %}{# not streaming #}{% endwith %}{# method_name #}
14661584

14671585
{% endfor -%} {#- method in methods for rest #}
1586+
1587+
{% for method in service.methods.values() if 'rest' in opts.transport and
1588+
not method.http_options %}{% with method_name = method.name|snake_case + "_unary" if method.operation_service else method.name|snake_case %}
1589+
def test_{{ method_name }}_rest_error():
1590+
client = {{ service.client_name }}(
1591+
credentials=ga_credentials.AnonymousCredentials(),
1592+
transport='rest'
1593+
)
1594+
# Since a `google.api.http` annotation is required for using a rest transport
1595+
# method, this should error.
1596+
with pytest.raises(RuntimeError) as runtime_error:
1597+
client.{{ method_name }}({})
1598+
assert ("Cannot define a method without a valid 'google.api.http' annotation."
1599+
in str(runtime_error.value))
1600+
1601+
1602+
{% endwith %}{# method_name #}
1603+
{% endfor %}{# for methods without http_options #}
1604+
14681605
def test_credentials_transport_error():
14691606
# It is an error to provide credentials and a transport instance.
14701607
transport = transports.{{ service.name }}{{ opts.transport[0].capitalize() }}Transport(
@@ -1758,8 +1895,7 @@ def test_{{ service.name|snake_case }}_http_transport_client_cert_source_for_mtl
17581895
mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback)
17591896

17601897

1761-
{# TODO(kbandes): re-enable this code when LRO is implmented for rest #}
1762-
{% if False and service.has_lro -%}
1898+
{% if service.has_lro -%}
17631899
def test_{{ service.name|snake_case }}_rest_lro_client():
17641900
client = {{ service.client_name }}(
17651901
credentials=ga_credentials.AnonymousCredentials(),
@@ -1770,7 +1906,7 @@ def test_{{ service.name|snake_case }}_rest_lro_client():
17701906
# Ensure that we have a api-core operations client.
17711907
assert isinstance(
17721908
transport.operations_client,
1773-
operations_v1.OperationsClient,
1909+
operations_v1.AbstractOperationsClient,
17741910
)
17751911

17761912
# Ensure that subsequent calls to the property send the exact same object.

tests/integration/goldens/asset/tests/unit/gapic/asset_v1/test_asset_service.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import grpc
2020
from grpc.experimental import aio
21+
import json
2122
import math
2223
import pytest
2324
from proto.marshal.rules.dates import DurationRule, TimestampRule

tests/integration/goldens/credentials/tests/unit/gapic/credentials_v1/test_iam_credentials.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import grpc
2020
from grpc.experimental import aio
21+
import json
2122
import math
2223
import pytest
2324
from proto.marshal.rules.dates import DurationRule, TimestampRule

tests/integration/goldens/logging/tests/unit/gapic/logging_v2/test_config_service_v2.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import grpc
2020
from grpc.experimental import aio
21+
import json
2122
import math
2223
import pytest
2324
from proto.marshal.rules.dates import DurationRule, TimestampRule

tests/integration/goldens/logging/tests/unit/gapic/logging_v2/test_logging_service_v2.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import grpc
2020
from grpc.experimental import aio
21+
import json
2122
import math
2223
import pytest
2324
from proto.marshal.rules.dates import DurationRule, TimestampRule

tests/integration/goldens/logging/tests/unit/gapic/logging_v2/test_metrics_service_v2.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import grpc
2020
from grpc.experimental import aio
21+
import json
2122
import math
2223
import pytest
2324
from proto.marshal.rules.dates import DurationRule, TimestampRule

tests/integration/goldens/redis/tests/unit/gapic/redis_v1/test_cloud_redis.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import grpc
2020
from grpc.experimental import aio
21+
import json
2122
import math
2223
import pytest
2324
from proto.marshal.rules.dates import DurationRule, TimestampRule

0 commit comments

Comments
 (0)