Skip to content

Commit e043098

Browse files
committed
feat: support pipeline versioning
For commit: aws/sagemaker-python-sdk-staging@9bfe85a
1 parent c4ca393 commit e043098

File tree

2 files changed

+104
-2
lines changed

2 files changed

+104
-2
lines changed

sagemaker-mlops/src/sagemaker/mlops/workflow/pipeline.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,15 @@ def __init__(
130130
self.sagemaker_session.boto_session.client("scheduler"),
131131
)
132132

133+
@property
134+
def latest_pipeline_version_id(self):
135+
"""Retrieves the latest version id of this pipeline"""
136+
summaries = self.list_pipeline_versions(max_results=1)["PipelineVersionSummaries"]
137+
if not summaries:
138+
return None
139+
else:
140+
return summaries[0].get("PipelineVersionId")
141+
133142
def create(
134143
self,
135144
role_arn: str = None,
@@ -219,15 +228,22 @@ def _create_args(
219228
)
220229
return kwargs
221230

222-
def describe(self) -> Dict[str, Any]:
231+
def describe(self, pipeline_version_id: int = None) -> Dict[str, Any]:
223232
"""Describes a Pipeline in the Workflow service.
224233
234+
Args:
235+
pipeline_version_id (Optional[str]): version ID of the pipeline to describe.
236+
225237
Returns:
226238
Response dict from the service. See `boto3 client documentation
227239
<https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/\
228240
sagemaker.html#SageMaker.Client.describe_pipeline>`_
229241
"""
230-
return self.sagemaker_session.sagemaker_client.describe_pipeline(PipelineName=self.name)
242+
kwargs = dict(PipelineName=self.name)
243+
if pipeline_version_id:
244+
kwargs["PipelineVersionId"] = pipeline_version_id
245+
246+
return self.sagemaker_session.sagemaker_client.describe_pipeline(**kwargs)
231247

232248
def update(
233249
self,
@@ -337,6 +353,7 @@ def start(
337353
execution_description: str = None,
338354
parallelism_config: ParallelismConfiguration = None,
339355
selective_execution_config: SelectiveExecutionConfig = None,
356+
pipeline_version_id: int = None,
340357
):
341358
"""Starts a Pipeline execution in the Workflow service.
342359
@@ -350,6 +367,8 @@ def start(
350367
over the parallelism configuration of the parent pipeline.
351368
selective_execution_config (Optional[SelectiveExecutionConfig]): The configuration for
352369
selective step execution.
370+
pipeline_version_id (Optional[str]): version ID of the pipeline to start the execution from. If not
371+
specified, uses the latest version ID.
353372
354373
Returns:
355374
A `_PipelineExecution` instance, if successful.
@@ -371,6 +390,7 @@ def start(
371390
PipelineExecutionDisplayName=execution_display_name,
372391
ParallelismConfiguration=parallelism_config,
373392
SelectiveExecutionConfig=selective_execution_config,
393+
PipelineVersionId=pipeline_version_id,
374394
)
375395
if self.sagemaker_session.local_mode:
376396
update_args(kwargs, PipelineParameters=parameters)
@@ -466,6 +486,32 @@ def list_executions(
466486
if key in response
467487
}
468488

489+
def list_pipeline_versions(
490+
self, sort_order: str = None, max_results: int = None, next_token: str = None
491+
) -> str:
492+
"""Lists a pipeline's versions.
493+
494+
Args:
495+
sort_order (str): The sort order for results (Ascending/Descending).
496+
max_results (int): The maximum number of pipeline executions to return in the response.
497+
next_token (str): If the result of the previous `ListPipelineExecutions` request was
498+
truncated, the response includes a `NextToken`. To retrieve the next set of pipeline
499+
executions, use the token in the next request.
500+
501+
Returns:
502+
List of Pipeline Version Summaries. See
503+
boto3 client list_pipeline_versions
504+
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker/client/list_pipeline_versions.html#
505+
"""
506+
kwargs = dict(PipelineName=self.name)
507+
update_args(
508+
kwargs,
509+
SortOrder=sort_order,
510+
NextToken=next_token,
511+
MaxResults=max_results,
512+
)
513+
return self.sagemaker_session.sagemaker_client.list_pipeline_versions(**kwargs)
514+
469515
def _get_latest_execution_arn(self):
470516
"""Retrieves the latest execution of this pipeline"""
471517
response = self.list_executions(

sagemaker-mlops/tests/unit/workflow/test_pipeline.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,62 @@ def test_pipeline_graph_iteration(mock_step):
417417
assert len(steps) == 1
418418

419419

420+
421+
422+
423+
def test_generate_step_map_duplicate_names():
424+
from sagemaker.mlops.workflow.pipeline import _generate_step_map
425+
426+
step1 = Mock(spec=Step)
427+
step1.name = "duplicate"
428+
step2 = Mock(spec=Step)
429+
step2.name = "duplicate"
430+
431+
step_map = {}
432+
with pytest.raises(ValueError, match="duplicate names"):
433+
_generate_step_map([step1, step2], step_map)
434+
435+
436+
def test_pipeline_latest_version_id(mock_session, mock_step):
437+
pipeline = Pipeline(name="test-pipeline", steps=[mock_step], sagemaker_session=mock_session)
438+
mock_session.sagemaker_client.list_pipeline_versions.return_value = {
439+
"PipelineVersionSummaries": [{"PipelineVersionId": 123}]
440+
}
441+
assert pipeline.latest_pipeline_version_id == 123
442+
443+
444+
def test_pipeline_latest_version_id_none(mock_session, mock_step):
445+
pipeline = Pipeline(name="test-pipeline", steps=[mock_step], sagemaker_session=mock_session)
446+
mock_session.sagemaker_client.list_pipeline_versions.return_value = {
447+
"PipelineVersionSummaries": []
448+
}
449+
assert pipeline.latest_pipeline_version_id is None
450+
451+
452+
def test_pipeline_describe_with_version_id(mock_session, mock_step):
453+
pipeline = Pipeline(name="test-pipeline", steps=[mock_step], sagemaker_session=mock_session)
454+
pipeline.describe(pipeline_version_id=123)
455+
mock_session.sagemaker_client.describe_pipeline.assert_called_once_with(
456+
PipelineName="test-pipeline", PipelineVersionId=123
457+
)
458+
459+
460+
def test_pipeline_start_with_version_id(mock_session, mock_step):
461+
pipeline = Pipeline(name="test-pipeline", steps=[mock_step], sagemaker_session=mock_session)
462+
mock_session.sagemaker_client.start_pipeline_execution.return_value = {"PipelineExecutionArn": "arn"}
463+
pipeline.start(pipeline_version_id=123)
464+
call_kwargs = mock_session.sagemaker_client.start_pipeline_execution.call_args[1]
465+
assert call_kwargs["PipelineVersionId"] == 123
466+
467+
468+
def test_pipeline_list_versions(mock_session, mock_step):
469+
pipeline = Pipeline(name="test-pipeline", steps=[mock_step], sagemaker_session=mock_session)
470+
pipeline.list_pipeline_versions(sort_order="Descending", max_results=10)
471+
mock_session.sagemaker_client.list_pipeline_versions.assert_called_once_with(
472+
PipelineName="test-pipeline", SortOrder="Descending", MaxResults=10
473+
)
474+
475+
420476
def test_pipeline_execution_result_waiter_error(mock_session):
421477
from sagemaker.mlops.workflow.pipeline import _PipelineExecution
422478
from botocore.exceptions import WaiterError

0 commit comments

Comments
 (0)