Skip to content

Commit 8c41641

Browse files
authored
chore: Refactor OpenApiEditor and SwaggerEditor to reduce duplicate code (#2682)
1 parent 765d402 commit 8c41641

File tree

7 files changed

+322
-473
lines changed

7 files changed

+322
-473
lines changed

samtranslator/model/api/api_generator.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ def _add_endpoint_extension(self): # type: ignore[no-untyped-def]
322322
raise InvalidResourceException(
323323
self.logical_id, "DisableExecuteApiEndpoint works only within 'DefinitionBody' property."
324324
)
325-
editor = SwaggerEditor(self.definition_body) # type: ignore[no-untyped-call]
325+
editor = SwaggerEditor(self.definition_body)
326326
editor.add_disable_execute_api_endpoint_extension(self.disable_execute_api_endpoint) # type: ignore[no-untyped-call]
327327
self.definition_body = editor.swagger
328328

@@ -644,7 +644,7 @@ def _add_cors(self): # type: ignore[no-untyped-def]
644644
else:
645645
raise InvalidResourceException(self.logical_id, INVALID_ERROR)
646646

647-
if not SwaggerEditor.is_valid(self.definition_body): # type: ignore[no-untyped-call]
647+
if not SwaggerEditor.is_valid(self.definition_body):
648648
raise InvalidResourceException(
649649
self.logical_id,
650650
"Unable to add Cors configuration because "
@@ -659,7 +659,7 @@ def _add_cors(self): # type: ignore[no-untyped-def]
659659
"'AllowOrigin' is \"'*'\" or not set",
660660
)
661661

662-
editor = SwaggerEditor(self.definition_body) # type: ignore[no-untyped-call]
662+
editor = SwaggerEditor(self.definition_body)
663663
for path in editor.iter_on_path():
664664
try:
665665
editor.add_cors( # type: ignore[no-untyped-call]
@@ -688,7 +688,7 @@ def _add_binary_media_types(self): # type: ignore[no-untyped-def]
688688
if self.binary_media and not self.definition_body:
689689
return
690690

691-
editor = SwaggerEditor(self.definition_body) # type: ignore[no-untyped-call]
691+
editor = SwaggerEditor(self.definition_body)
692692
editor.add_binary_media_types(self.binary_media) # type: ignore[no-untyped-call]
693693

694694
# Assign the Swagger back to template
@@ -711,13 +711,13 @@ def _add_auth(self): # type: ignore[no-untyped-def]
711711
if not all(key in AuthProperties._fields for key in self.auth.keys()):
712712
raise InvalidResourceException(self.logical_id, "Invalid value for 'Auth' property")
713713

714-
if not SwaggerEditor.is_valid(self.definition_body): # type: ignore[no-untyped-call]
714+
if not SwaggerEditor.is_valid(self.definition_body):
715715
raise InvalidResourceException(
716716
self.logical_id,
717717
"Unable to add Auth configuration because "
718718
"'DefinitionBody' does not contain a valid Swagger definition.",
719719
)
720-
swagger_editor = SwaggerEditor(self.definition_body) # type: ignore[no-untyped-call]
720+
swagger_editor = SwaggerEditor(self.definition_body)
721721
auth_properties = AuthProperties(**self.auth)
722722
authorizers = self._get_authorizers(auth_properties.Authorizers, auth_properties.DefaultAuthorizer) # type: ignore[no-untyped-call]
723723

@@ -731,8 +731,8 @@ def _add_auth(self): # type: ignore[no-untyped-def]
731731
)
732732

733733
if auth_properties.ApiKeyRequired:
734-
swagger_editor.add_apikey_security_definition() # type: ignore[no-untyped-call]
735-
self._set_default_apikey_required(swagger_editor) # type: ignore[no-untyped-call]
734+
swagger_editor.add_apikey_security_definition()
735+
self._set_default_apikey_required(swagger_editor)
736736

737737
if auth_properties.ResourcePolicy:
738738
SwaggerEditor.validate_is_dict(
@@ -946,14 +946,14 @@ def _add_gateway_responses(self): # type: ignore[no-untyped-def]
946946
),
947947
)
948948

949-
if not SwaggerEditor.is_valid(self.definition_body): # type: ignore[no-untyped-call]
949+
if not SwaggerEditor.is_valid(self.definition_body):
950950
raise InvalidResourceException(
951951
self.logical_id,
952952
"Unable to add Auth configuration because "
953953
"'DefinitionBody' does not contain a valid Swagger definition.",
954954
)
955955

956-
swagger_editor = SwaggerEditor(self.definition_body) # type: ignore[no-untyped-call]
956+
swagger_editor = SwaggerEditor(self.definition_body)
957957

958958
# The dicts below will eventually become part of swagger/openapi definition, thus requires using Py27Dict()
959959
gateway_responses = Py27Dict()
@@ -992,7 +992,7 @@ def _add_models(self): # type: ignore[no-untyped-def]
992992
self.logical_id, "Models works only with inline Swagger specified in 'DefinitionBody' property."
993993
)
994994

995-
if not SwaggerEditor.is_valid(self.definition_body): # type: ignore[no-untyped-call]
995+
if not SwaggerEditor.is_valid(self.definition_body):
996996
raise InvalidResourceException(
997997
self.logical_id,
998998
"Unable to add Models definitions because "
@@ -1002,7 +1002,7 @@ def _add_models(self): # type: ignore[no-untyped-def]
10021002
if not all(isinstance(model, dict) for model in self.models.values()):
10031003
raise InvalidResourceException(self.logical_id, "Invalid value for 'Models' property")
10041004

1005-
swagger_editor = SwaggerEditor(self.definition_body) # type: ignore[no-untyped-call]
1005+
swagger_editor = SwaggerEditor(self.definition_body)
10061006
swagger_editor.add_models(self.models) # type: ignore[no-untyped-call]
10071007

10081008
# Assign the Swagger back to template
@@ -1023,7 +1023,7 @@ def _openapi_postprocess(self, definition_body): # type: ignore[no-untyped-def]
10231023
self.open_api_version = definition_body.get("openapi")
10241024

10251025
if self.open_api_version and SwaggerEditor.safe_compare_regex_with_string(
1026-
SwaggerEditor.get_openapi_version_3_regex(), self.open_api_version
1026+
SwaggerEditor._OPENAPI_VERSION_3_REGEX, self.open_api_version
10271027
):
10281028
if definition_body.get("securityDefinitions"):
10291029
components = definition_body.get("components", Py27Dict())
@@ -1208,7 +1208,7 @@ def _set_default_authorizer(
12081208
add_default_auth_to_preflight=add_default_auth_to_preflight,
12091209
)
12101210

1211-
def _set_default_apikey_required(self, swagger_editor): # type: ignore[no-untyped-def]
1211+
def _set_default_apikey_required(self, swagger_editor: SwaggerEditor) -> None:
12121212
for path in swagger_editor.iter_on_path():
12131213
swagger_editor.set_path_default_apikey_required(path)
12141214

samtranslator/model/eventsources/push.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -726,9 +726,9 @@ def _add_swagger_integration(self, api, function, intrinsics_resolver): # type:
726726
partition = ArnGenerator.get_partition_name() # type: ignore[no-untyped-call]
727727
uri = _build_apigw_integration_uri(function, partition) # type: ignore[no-untyped-call]
728728

729-
editor = SwaggerEditor(swagger_body) # type: ignore[no-untyped-call]
729+
editor = SwaggerEditor(swagger_body)
730730

731-
if editor.has_integration(self.Path, self.Method): # type: ignore[attr-defined, no-untyped-call]
731+
if editor.has_integration(self.Path, self.Method): # type: ignore[attr-defined]
732732
# Cannot add the Lambda Integration, if it is already present
733733
raise InvalidEventException(
734734
self.relative_id,
@@ -1232,7 +1232,7 @@ def _add_openapi_integration(self, api, function, manage_swagger=False): # type
12321232

12331233
editor = OpenApiEditor(open_api_body)
12341234

1235-
if manage_swagger and editor.has_integration(self.Path, self.Method): # type: ignore[attr-defined, no-untyped-call]
1235+
if manage_swagger and editor.has_integration(self.Path, self.Method): # type: ignore[attr-defined]
12361236
# Cannot add the Lambda Integration, if it is already present
12371237
raise InvalidEventException(
12381238
self.relative_id,

samtranslator/model/stepfunctions/events.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -350,9 +350,9 @@ def _add_swagger_integration(self, api, resource, role, intrinsics_resolver): #
350350

351351
integration_uri = fnSub("arn:${AWS::Partition}:apigateway:${AWS::Region}:states:action/StartExecution")
352352

353-
editor = SwaggerEditor(swagger_body) # type: ignore[no-untyped-call]
353+
editor = SwaggerEditor(swagger_body)
354354

355-
if editor.has_integration(self.Path, self.Method): # type: ignore[attr-defined, no-untyped-call]
355+
if editor.has_integration(self.Path, self.Method): # type: ignore[attr-defined]
356356
# Cannot add the integration, if it is already present
357357
raise InvalidEventException(
358358
self.relative_id,
Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
"""Base class for OpenApiEditor and SwaggerEditor."""
2+
import re
3+
from typing import Any, Dict, Iterator, List, Optional, Union
4+
5+
from samtranslator.model.apigateway import ApiGatewayAuthorizer
6+
from samtranslator.model.apigatewayv2 import ApiGatewayV2Authorizer
7+
from samtranslator.model.exceptions import InvalidDocumentException, InvalidTemplateException
8+
from samtranslator.model.intrinsics import is_intrinsic_no_value, make_conditional
9+
from samtranslator.utils.py27hash_fix import Py27Dict
10+
11+
12+
class BaseEditor(object):
13+
# constants:
14+
_X_APIGW_INTEGRATION = "x-amazon-apigateway-integration"
15+
_CONDITIONAL_IF = "Fn::If"
16+
_X_ANY_METHOD = "x-amazon-apigateway-any-method"
17+
# https://www.w3.org/Protocols/rfc2616/rfc2616-sec9.html
18+
_ALL_HTTP_METHODS = ["OPTIONS", "GET", "HEAD", "POST", "PUT", "DELETE", "PATCH"]
19+
_SERVERS = "servers"
20+
_OPENAPI_VERSION_3_REGEX = r"\A3(\.\d)(\.\d)?$"
21+
22+
# attributes:
23+
_doc: Dict[str, Any]
24+
paths: Dict[str, Any]
25+
26+
@staticmethod
27+
def get_conditional_contents(item: Any) -> List[Any]:
28+
"""
29+
Returns the contents of the given item.
30+
If a conditional block has been used inside the item, returns a list of the content
31+
inside the conditional (both the then and the else cases). Skips {'Ref': 'AWS::NoValue'} content.
32+
If there's no conditional block, then returns an list with the single item in it.
33+
34+
:param dict item: item from which the contents will be extracted
35+
:return: list of item content
36+
"""
37+
contents = [item]
38+
if isinstance(item, dict) and BaseEditor._CONDITIONAL_IF in item:
39+
contents = item[BaseEditor._CONDITIONAL_IF][1:]
40+
contents = [content for content in contents if not is_intrinsic_no_value(content)]
41+
return contents
42+
43+
@staticmethod
44+
def method_definition_has_integration(method_definition: Dict[str, Any]) -> bool:
45+
"""
46+
Checks a method definition to make sure it has an apigw integration
47+
48+
:param dict method_definition: method definition dictionary
49+
:return: True if an integration exists
50+
"""
51+
52+
return bool(method_definition.get(BaseEditor._X_APIGW_INTEGRATION))
53+
54+
def method_has_integration(self, method: Dict[str, Any]) -> bool:
55+
"""
56+
Returns true if the given method contains a valid method definition.
57+
This uses the get_conditional_contents function to handle conditionals.
58+
59+
:param dict method: method dictionary
60+
:return: true if method has one or multiple integrations
61+
"""
62+
for method_definition in self.get_conditional_contents(method):
63+
if self.method_definition_has_integration(method_definition):
64+
return True
65+
return False
66+
67+
def make_path_conditional(self, path: str, condition: str) -> None:
68+
"""
69+
Wrap entire API path definition in a CloudFormation if condition.
70+
:param path: path name
71+
:param condition: condition name
72+
"""
73+
self.paths[path] = make_conditional(condition, self.paths[path])
74+
75+
def iter_on_path(self) -> Iterator[str]:
76+
"""
77+
Yields all the paths available in the Swagger. As a caller, if you add new paths to Swagger while iterating,
78+
they will not show up in this iterator
79+
80+
:yields string: Path name
81+
"""
82+
83+
for path, _ in self.paths.items():
84+
yield path
85+
86+
@staticmethod
87+
def _normalize_method_name(method: Any) -> Any:
88+
"""
89+
Returns a lower case, normalized version of HTTP Method. It also know how to handle API Gateway specific methods
90+
like "ANY"
91+
92+
NOTE: Always normalize before using the `method` value passed in as input
93+
94+
:param string method: Name of the HTTP Method
95+
:return string: Normalized method name
96+
"""
97+
if not method or not isinstance(method, str):
98+
return method
99+
100+
method = method.lower()
101+
if method == "any":
102+
return BaseEditor._X_ANY_METHOD
103+
return method
104+
105+
def has_path(self, path: str, method: Optional[str] = None) -> bool:
106+
"""
107+
Returns True if this Swagger has the given path and optional method
108+
For paths with conditionals, only returns true if both items (true case, and false case) have the method.
109+
110+
:param string path: Path name
111+
:param string method: HTTP method
112+
:return: True, if this path/method is present in the document
113+
"""
114+
if path not in self.paths:
115+
return False
116+
117+
method = self._normalize_method_name(method)
118+
if method:
119+
for path_item in self.get_conditional_contents(self.paths.get(path)):
120+
if not path_item or method not in path_item:
121+
return False
122+
return True
123+
124+
def has_integration(self, path: str, method: str) -> bool:
125+
"""
126+
Checks if an API Gateway integration is already present at the given path/method.
127+
For paths with conditionals, it only returns True if both items (true case, false case) have the integration
128+
129+
:param string path: Path name
130+
:param string method: HTTP method
131+
:return: True, if an API Gateway integration is already present
132+
"""
133+
method = self._normalize_method_name(method)
134+
135+
if not self.has_path(path, method):
136+
return False
137+
138+
for path_item in self.get_conditional_contents(self.paths.get(path)):
139+
method_definition = path_item.get(method)
140+
if not (isinstance(method_definition, dict) and self.method_has_integration(method_definition)):
141+
return False
142+
# Integration present and non-empty
143+
return True
144+
145+
def add_path(self, path: str, method: Optional[str] = None) -> None:
146+
"""
147+
Adds the path/method combination to the Swagger, if not already present
148+
149+
:param string path: Path name
150+
:param string method: HTTP method
151+
:raises InvalidDocumentException: If the value of `path` in Swagger is not a dictionary
152+
"""
153+
method = self._normalize_method_name(method)
154+
155+
path_dict = self.paths.setdefault(path, Py27Dict())
156+
157+
if not isinstance(path_dict, dict):
158+
# Either customers has provided us an invalid Swagger, or this class has messed it somehow
159+
raise InvalidDocumentException(
160+
[InvalidTemplateException(f"Value of '{path}' path must be a dictionary according to Swagger spec.")]
161+
)
162+
163+
for path_item in self.get_conditional_contents(path_dict):
164+
path_item.setdefault(method, Py27Dict())
165+
166+
@staticmethod
167+
def _get_authorization_scopes(
168+
authorizers: Union[Dict[str, ApiGatewayAuthorizer], Dict[str, ApiGatewayV2Authorizer]], default_authorizer: str
169+
) -> Any:
170+
"""
171+
Returns auth scopes for an authorizer if present
172+
:param authorizers: authorizer definitions
173+
:param default_authorizer: name of the default authorizer
174+
"""
175+
authorizer = authorizers.get(default_authorizer)
176+
if authorizer and authorizer.authorization_scopes is not None:
177+
return authorizer.authorization_scopes
178+
return []
179+
180+
def iter_on_method_definitions_for_path_at_method(
181+
self, path_name: str, method_name: str, skip_methods_without_apigw_integration: bool = True
182+
) -> Iterator[Dict[str, Any]]:
183+
"""
184+
Yields all the method definitions for the path+method combinations if path and/or method have IF conditionals.
185+
If there are no conditionals, will just yield the single method definition at the given path and method name.
186+
187+
:param path_name: path name
188+
:param method_name: method name
189+
:param skip_methods_without_apigw_integration: if True, skips method definitions without apigw integration
190+
:yields dict: method definition
191+
"""
192+
normalized_method_name = self._normalize_method_name(method_name)
193+
194+
for path_item in self.get_conditional_contents(self.paths.get(path_name)):
195+
for method_definition in self.get_conditional_contents(path_item.get(normalized_method_name)):
196+
if skip_methods_without_apigw_integration and not self.method_definition_has_integration(
197+
method_definition
198+
):
199+
continue
200+
yield method_definition
201+
202+
@staticmethod
203+
def validate_is_dict(obj: Any, exception_message: str) -> None:
204+
"""
205+
Throws exception if obj is not a dict
206+
207+
:param obj: object being validated
208+
:param exception_message: message to include in exception if obj is not a dict
209+
"""
210+
211+
if not isinstance(obj, dict):
212+
raise InvalidDocumentException([InvalidTemplateException(exception_message)])
213+
214+
@staticmethod
215+
def validate_path_item_is_dict(path_item: Any, path: str) -> None:
216+
"""
217+
Throws exception if path_item is not a dict
218+
219+
:param path_item: path_item (value at the path) being validated
220+
:param path: path name
221+
"""
222+
223+
BaseEditor.validate_is_dict(
224+
path_item, "Value of '{}' path must be a dictionary according to Swagger spec.".format(path)
225+
)
226+
227+
@staticmethod
228+
def safe_compare_regex_with_string(regex: str, data: Any) -> bool:
229+
return re.match(regex, str(data)) is not None

0 commit comments

Comments
 (0)