@@ -7,6 +7,7 @@ import mock
77
88import grpc
99from grpc.experimental import aio
10+ import json
1011import math
1112import pytest
1213from proto.marshal.rules.dates import DurationRule, TimestampRule
@@ -1187,6 +1188,7 @@ def test_{{ method_name }}_rest(transport: str = 'rest', request_type={{ method.
11871188 {% if "next_page_token" in method .output .fields .values ()|map (attribute ='name' ) and not method .paged_result_field %}
11881189 {# Cheeser assertion to force code coverage for bad paginated methods #}
11891190 assert response.raw_page is response
1191+
11901192 {% endif %}
11911193
11921194 # Establish that the response is the type that we expect.
@@ -1210,6 +1212,130 @@ def test_{{ method_name }}_rest(transport: str = 'rest', request_type={{ method.
12101212 {% endif %}
12111213
12121214
1215+ {% if method .input .required_fields %}
1216+ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ident }}):
1217+ transport_class = transports.{{ service.rest_transport_name }}
1218+
1219+ request_init = {}
1220+ {% for req_field in method .input .required_fields if req_field .is_primitive %}
1221+ {% if req_field .field_pb .default_value is string %}
1222+ request_init["{{ req_field.name }}"] = "{{ req_field.field_pb.default_value }}"
1223+ {% else %}
1224+ request_init["{{ req_field.name }}"] = {{ req_field.field_pb.default_value }}
1225+ {% endif %} {# default is str #}
1226+ {% endfor %}
1227+ request = request_type(request_init)
1228+ jsonified_request = json.loads(request_type.to_json(
1229+ request,
1230+ including_default_value_fields=False,
1231+ use_integers_for_enums=False
1232+ ))
1233+
1234+ # verify fields with default values are dropped
1235+ {% for req_field in method .input .required_fields if req_field .is_primitive %}
1236+ {% set field_name = req_field .name | camel_case %}
1237+ assert "{{ field_name }}" not in jsonified_request
1238+ {% endfor %}
1239+
1240+ unset_fields = transport_class._{{ method.name | snake_case }}_get_unset_required_fields(jsonified_request)
1241+ jsonified_request.update(unset_fields)
1242+
1243+ # verify required fields with default values are now present
1244+ {% for req_field in method .input .required_fields if req_field .is_primitive %}
1245+ {% set field_name = req_field .name | camel_case %}
1246+ assert "{{ field_name }}" in jsonified_request
1247+ assert jsonified_request["{{ field_name }}"] == request_init["{{ req_field.name }}"]
1248+ {% endfor %}
1249+
1250+ {% for req_field in method .input .required_fields if req_field .is_primitive %}
1251+ {% set field_name = req_field .name | camel_case %}
1252+ {% set mock_value = req_field .primitive_mock_as_str () %}
1253+ jsonified_request["{{ field_name }}"] = {{ mock_value }}
1254+ {% endfor %}
1255+
1256+ unset_fields = transport_class._{{ method.name | snake_case }}_get_unset_required_fields(jsonified_request)
1257+ jsonified_request.update(unset_fields)
1258+
1259+ # verify required fields with non-default values are left alone
1260+ {% for req_field in method .input .required_fields if req_field .is_primitive %}
1261+ {% set field_name = req_field .name | camel_case %}
1262+ {% set mock_value = req_field .primitive_mock_as_str () %}
1263+ assert "{{ field_name }}" in jsonified_request
1264+ assert jsonified_request["{{ field_name }}"] == {{ mock_value }}
1265+ {% endfor %}
1266+
1267+
1268+ client = {{ service.client_name }}(
1269+ credentials=ga_credentials.AnonymousCredentials(),
1270+ transport='rest',
1271+ )
1272+ request = request_type(request_init)
1273+
1274+ # Designate an appropriate value for the returned response.
1275+ {% if method .void %}
1276+ return_value = None
1277+ {% elif method .lro %}
1278+ return_value = operations_pb2.Operation(name='operations/spam')
1279+ {% elif method .server_streaming %}
1280+ return_value = iter([{{ method.output.ident }}()])
1281+ {% else %}
1282+ return_value = {{ method.output.ident }}()
1283+ {% endif %}
1284+ # Mock the http request call within the method and fake a response.
1285+ with mock.patch.object(Session, 'request') as req:
1286+ # We need to mock transcode() because providing default values
1287+ # for required fields will fail the real version if the http_options
1288+ # expect actual values for those fields.
1289+ with mock.patch.object(path_template, 'transcode') as transcode:
1290+ # A uri without fields and an empty body will force all the
1291+ # request fields to show up in the query_params.
1292+ transcode_result = {
1293+ 'uri': 'v1/sample_method',
1294+ 'method': "{{ method.http_options[0] .method }}",
1295+ 'query_params': request_init,
1296+ }
1297+ {% if method .http_options [0].body %}
1298+ transcode_result['body'] = {}
1299+ {% endif %}
1300+ transcode.return_value = transcode_result
1301+
1302+ response_value = Response()
1303+ response_value.status_code = 200
1304+ {% if method .void %}
1305+ json_return_value = ''
1306+ {% elif method .lro %}
1307+ json_return_value = json_format.MessageToJson(return_value)
1308+ {% else %}
1309+ json_return_value = {{ method.output.ident }}.to_json(return_value)
1310+ {% endif %}
1311+ response_value._content = json_return_value.encode('UTF-8')
1312+ req.return_value = response_value
1313+
1314+ {% if method .client_streaming %}
1315+ response = client.{{ method.name|snake_case }}(iter(requests))
1316+ {% else %}
1317+ response = client.{{ method_name }}(request)
1318+ {% endif %}
1319+
1320+ expected_params = [
1321+ {% for req_field in method .input .required_fields if req_field .is_primitive %}
1322+ (
1323+ "{{ req_field.name }}",
1324+ {% if req_field .field_pb .default_value is string %}
1325+ "{{ req_field.field_pb.default_value }}"
1326+ {% else %}
1327+ {{ req_field.field_pb.default_value }}
1328+ {% endif %} {# default is str #}
1329+ )
1330+ {% endfor %}
1331+ ]
1332+ actual_params = req.call_args.kwargs['params']
1333+ assert expected_params == actual_params
1334+
1335+
1336+ {% endif %} {# required_fields #}
1337+
1338+
12131339def test_{{ method_name }}_rest_bad_request(transport: str = 'rest', request_type={{ method.input.ident }}):
12141340 client = {{ service.client_name }}(
12151341 credentials=ga_credentials.AnonymousCredentials(),
@@ -1325,9 +1451,10 @@ def test_{{ method_name }}_rest_flattened_error(transport: str = 'rest'):
13251451
13261452
13271453{% if method .paged_result_field %}
1328- def test_{{ method_name }}_rest_pager():
1454+ def test_{{ method_name }}_rest_pager(transport: str = 'rest' ):
13291455 client = {{ service.client_name }}(
13301456 credentials=ga_credentials.AnonymousCredentials(),
1457+ transport=transport,
13311458 )
13321459
13331460 # Mock the http request call within the method and fake a response.
@@ -1446,25 +1573,35 @@ def test_{{ method_name }}_rest_error():
14461573 credentials=ga_credentials.AnonymousCredentials(),
14471574 transport='rest'
14481575 )
1449- {% - if not method .http_options %}
1450- # Since a `google.api.http` annotation is required for using a rest transport
1451- # method, this should error.
1452- with pytest.raises(RuntimeError) as runtime_error:
1453- client.{{ method_name }}({})
1454- assert ('Cannot define a method without a valid `google.api.http` annotation.'
1455- in str(runtime_error.value))
1456- {% - else %}
14571576
14581577 # TODO(yon-mg): Remove when this method has a working implementation
14591578 # or testing straegy
14601579 with pytest.raises(NotImplementedError):
14611580 client.{{ method_name }}({})
14621581
1463- {% - endif %}
14641582
1465- {% endif %}{% endwith %} {# method_name #}
1583+ {% endif %} {# not streaming #} { % endwith %} {# method_name #}
14661584
14671585{% endfor -%} {#- method in methods for rest #}
1586+
1587+ {% for method in service .methods .values () if 'rest' in opts .transport and
1588+ not method .http_options %}{% with method_name = method .name |snake_case + "_unary" if method .operation_service else method .name |snake_case %}
1589+ def test_{{ method_name }}_rest_error():
1590+ client = {{ service.client_name }}(
1591+ credentials=ga_credentials.AnonymousCredentials(),
1592+ transport='rest'
1593+ )
1594+ # Since a `google.api.http` annotation is required for using a rest transport
1595+ # method, this should error.
1596+ with pytest.raises(RuntimeError) as runtime_error:
1597+ client.{{ method_name }}({})
1598+ assert ("Cannot define a method without a valid 'google.api.http' annotation."
1599+ in str(runtime_error.value))
1600+
1601+
1602+ {% endwith %} {# method_name #}
1603+ {% endfor %} {# for methods without http_options #}
1604+
14681605def test_credentials_transport_error():
14691606 # It is an error to provide credentials and a transport instance.
14701607 transport = transports.{{ service.name }}{{ opts.transport[0] .capitalize() }}Transport(
@@ -1758,8 +1895,7 @@ def test_{{ service.name|snake_case }}_http_transport_client_cert_source_for_mtl
17581895 mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback)
17591896
17601897
1761- {# TODO(kbandes): re-enable this code when LRO is implmented for rest #}
1762- {% if False and service .has_lro -%}
1898+ {% if service .has_lro -%}
17631899def test_{{ service.name|snake_case }}_rest_lro_client():
17641900 client = {{ service.client_name }}(
17651901 credentials=ga_credentials.AnonymousCredentials(),
@@ -1770,7 +1906,7 @@ def test_{{ service.name|snake_case }}_rest_lro_client():
17701906 # Ensure that we have a api-core operations client.
17711907 assert isinstance(
17721908 transport.operations_client,
1773- operations_v1.OperationsClient ,
1909+ operations_v1.AbstractOperationsClient ,
17741910 )
17751911
17761912 # Ensure that subsequent calls to the property send the exact same object.
0 commit comments