From f5b568abbdd9520f21bbd3ada35d6011cb7917b9 Mon Sep 17 00:00:00 2001 From: Ashish Gupta Date: Wed, 6 Nov 2024 14:11:12 -0800 Subject: [PATCH 1/2] changes for blackbird - model sharding changes for blackbird - model sharding add more tests fix sharded model flag add optimization validations fix formatting and msging fixing validation bugs add UTs simplify logic update messaging formatting fix UTs add more UTs fix validations update ruleset update formatting update validation logic update bug fixes Disable network isolation if using sharded models. check sharding + network iso pre optimization add more UTs for sharding add more UTs --- src/sagemaker/model.py | 14 + .../serve/builder/jumpstart_builder.py | 23 +- src/sagemaker/serve/builder/model_builder.py | 85 ++- src/sagemaker/serve/utils/optimize_utils.py | 23 +- .../serve/validations/optimization.py | 225 ++++++ tests/unit/sagemaker/model/test_model.py | 50 ++ .../serve/builder/test_model_builder.py | 704 +++++++++++++++++- .../serve/utils/test_optimize_utils.py | 36 +- 8 files changed, 1138 insertions(+), 22 deletions(-) create mode 100644 src/sagemaker/serve/validations/optimization.py diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 340d35b250..83efa57cb8 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -372,6 +372,7 @@ def __init__( self.endpoint_name = None self.inference_component_name = None self._is_compiled_model = False + self._is_sharded_model = False self._compilation_job_name = None self._is_edge_packaged_model = False self.inference_recommender_job_results = None @@ -1599,6 +1600,19 @@ def deploy( if self._base_name is not None: self._base_name = "-".join((self._base_name, compiled_model_suffix)) + if self._is_sharded_model and endpoint_type != EndpointType.INFERENCE_COMPONENT_BASED: + logging.warning( + "Forcing INFERENCE_COMPONENT_BASED endpoint for sharded model. ADVISORY - " + "Use INFERENCE_COMPONENT_BASED endpoints over MODEL_BASED endpoints." 
+ ) + endpoint_type = EndpointType.INFERENCE_COMPONENT_BASED + + if self._is_sharded_model and self._enable_network_isolation: + raise ValueError( + "EnableNetworkIsolation cannot be set to True since SageMaker Fast Model " + "Loading of model requires network access." + ) + # Support multiple models on same endpoint if endpoint_type == EndpointType.INFERENCE_COMPONENT_BASED: if endpoint_name: diff --git a/src/sagemaker/serve/builder/jumpstart_builder.py b/src/sagemaker/serve/builder/jumpstart_builder.py index 7d6a052023..37a77179cb 100644 --- a/src/sagemaker/serve/builder/jumpstart_builder.py +++ b/src/sagemaker/serve/builder/jumpstart_builder.py @@ -684,6 +684,7 @@ def _optimize_for_jumpstart( quantization_config: Optional[Dict] = None, compilation_config: Optional[Dict] = None, speculative_decoding_config: Optional[Dict] = None, + sharding_config: Optional[Dict] = None, env_vars: Optional[Dict] = None, vpc_config: Optional[Dict] = None, kms_key: Optional[str] = None, @@ -705,6 +706,8 @@ def _optimize_for_jumpstart( compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``. speculative_decoding_config (Optional[Dict]): Speculative decoding configuration. Defaults to ``None`` + sharding_config (Optional[Dict]): Model sharding configuration. + Defaults to ``None`` env_vars (Optional[Dict]): Additional environment variables to run the optimization container. Defaults to ``None``. vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``. 
@@ -730,8 +733,13 @@ def _optimize_for_jumpstart( pysdk_model_env_vars = self._get_neuron_model_env_vars(instance_type) # optimization_config can contain configs for both quantization and compilation - optimization_config, quantization_override_env, compilation_override_env = ( - _extract_optimization_config_and_env(quantization_config, compilation_config) + ( + optimization_config, + quantization_override_env, + compilation_override_env, + sharding_override_env, + ) = _extract_optimization_config_and_env( + quantization_config, compilation_config, sharding_config ) if not optimization_config: @@ -807,11 +815,20 @@ def _optimize_for_jumpstart( { **(quantization_override_env or {}), **(compilation_override_env or {}), + **(sharding_override_env or {}), }, ) if optimization_env_vars: self.pysdk_model.env.update(optimization_env_vars) - if quantization_config or is_compilation: + + if sharding_config and self.pysdk_model._enable_network_isolation: + logger.warning( + "EnableNetworkIsolation cannot be set to True since SageMaker Fast Model " + "Loading of model requires network access. Setting it to False." 
+ ) + self.pysdk_model._enable_network_isolation = False + + if quantization_config or sharding_config or is_compilation: return create_optimization_job_args return None diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index 61af6953a2..cfb50584e0 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -105,6 +105,7 @@ get_huggingface_model_metadata, download_huggingface_model_metadata, ) +from sagemaker.serve.validations.optimization import _validate_optimization_configuration logger = logging.getLogger(__name__) @@ -1120,6 +1121,7 @@ def optimize( quantization_config: Optional[Dict] = None, compilation_config: Optional[Dict] = None, speculative_decoding_config: Optional[Dict] = None, + sharding_config: Optional[Dict] = None, env_vars: Optional[Dict] = None, vpc_config: Optional[Dict] = None, kms_key: Optional[str] = None, @@ -1143,6 +1145,8 @@ def optimize( compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``. speculative_decoding_config (Optional[Dict]): Speculative decoding configuration. Defaults to ``None`` + sharding_config (Optional[Dict]): Model sharding configuration. + Defaults to ``None`` env_vars (Optional[Dict]): Additional environment variables to run the optimization container. Defaults to ``None``. vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``. 
@@ -1171,6 +1175,7 @@ def optimize( quantization_config=quantization_config, compilation_config=compilation_config, speculative_decoding_config=speculative_decoding_config, + sharding_config=sharding_config, env_vars=env_vars, vpc_config=vpc_config, kms_key=kms_key, @@ -1190,6 +1195,7 @@ def _model_builder_optimize_wrapper( quantization_config: Optional[Dict] = None, compilation_config: Optional[Dict] = None, speculative_decoding_config: Optional[Dict] = None, + sharding_config: Optional[Dict] = None, env_vars: Optional[Dict] = None, vpc_config: Optional[Dict] = None, kms_key: Optional[str] = None, @@ -1213,6 +1219,8 @@ def _model_builder_optimize_wrapper( compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``. speculative_decoding_config (Optional[Dict]): Speculative decoding configuration. Defaults to ``None`` + sharding_config (Optional[Dict]): Model sharding configuration. + Defaults to ``None`` env_vars (Optional[Dict]): Additional environment variables to run the optimization container. Defaults to ``None``. vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``. @@ -1227,6 +1235,26 @@ def _model_builder_optimize_wrapper( Returns: Model: A deployable ``Model`` object. """ + if ( + hasattr(self, "enable_network_isolation") + and self.enable_network_isolation + and sharding_config + ): + raise ValueError( + "EnableNetworkIsolation cannot be set to True since SageMaker Fast Model " + "Loading of model requires network access." 
+            )
+
+        # TODO: ideally these dictionaries need to be sagemaker_core shapes
+        # TODO: for organization, abstract all validation behind this fn
+        _validate_optimization_configuration(
+            instance_type=instance_type,
+            quantization_config=quantization_config,
+            compilation_config=compilation_config,
+            sharding_config=sharding_config,
+            speculative_decoding_config=speculative_decoding_config,
+        )
+
         self.is_compiled = compilation_config is not None
         self.is_quantized = quantization_config is not None
         self.speculative_decoding_draft_model_source = _extract_speculative_draft_model_provider(
@@ -1236,6 +1264,36 @@
         if self.mode != Mode.SAGEMAKER_ENDPOINT:
             raise ValueError("Model optimization is only supported in Sagemaker Endpoint Mode.")
 
+        if sharding_config and (
+            quantization_config or compilation_config or speculative_decoding_config
+        ):
+            raise ValueError(
+                (
+                    "Sharding config is mutually exclusive "
+                    "and cannot be combined with any other optimization."
+                )
+            )
+
+        if sharding_config:
+            has_tensor_parallel_degree_in_env_vars = (
+                env_vars and "OPTION_TENSOR_PARALLEL_DEGREE" in env_vars
+            )
+            has_tensor_parallel_degree_in_overrides = (
+                sharding_config
+                and sharding_config.get("OverrideEnvironment")
+                and "OPTION_TENSOR_PARALLEL_DEGREE" in sharding_config.get("OverrideEnvironment")
+            )
+            if (
+                not has_tensor_parallel_degree_in_env_vars
+                and not has_tensor_parallel_degree_in_overrides
+            ):
+                raise ValueError(
+                    (
+                        "OPTION_TENSOR_PARALLEL_DEGREE is a required "
+                        "environment variable with sharding config."
+ ) + ) + self.sagemaker_session = sagemaker_session or self.sagemaker_session or Session() self.instance_type = instance_type or self.instance_type self.role_arn = role_arn or self.role_arn @@ -1252,6 +1317,7 @@ def _model_builder_optimize_wrapper( quantization_config=quantization_config, compilation_config=compilation_config, speculative_decoding_config=speculative_decoding_config, + sharding_config=sharding_config, env_vars=env_vars, vpc_config=vpc_config, kms_key=kms_key, @@ -1270,12 +1336,16 @@ def _model_builder_optimize_wrapper( quantization_config=quantization_config, compilation_config=compilation_config, speculative_decoding_config=speculative_decoding_config, + sharding_config=sharding_config, env_vars=env_vars, vpc_config=vpc_config, kms_key=kms_key, max_runtime_in_sec=max_runtime_in_sec, ) + if sharding_config: + self.pysdk_model._is_sharded_model = True + if input_args: optimization_instance_type = input_args["DeploymentInstanceType"] @@ -1325,6 +1395,7 @@ def _optimize_for_hf( quantization_config: Optional[Dict] = None, compilation_config: Optional[Dict] = None, speculative_decoding_config: Optional[Dict] = None, + sharding_config: Optional[Dict] = None, env_vars: Optional[Dict] = None, vpc_config: Optional[Dict] = None, kms_key: Optional[str] = None, @@ -1340,6 +1411,8 @@ def _optimize_for_hf( compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``. speculative_decoding_config (Optional[Dict]): Speculative decoding configuration. Defaults to ``None`` + sharding_config (Optional[Dict]): Model sharding configuration. + Defaults to ``None`` env_vars (Optional[Dict]): Additional environment variables to run the optimization container. Defaults to ``None``. vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``. 
@@ -1363,7 +1436,7 @@ def _optimize_for_hf(
             self.pysdk_model, speculative_decoding_config, False
         )
 
-        if quantization_config or compilation_config:
+        if quantization_config or compilation_config or sharding_config:
             create_optimization_job_args = {
                 "OptimizationJobName": job_name,
                 "DeploymentInstanceType": self.instance_type,
@@ -1378,9 +1451,14 @@ def _optimize_for_hf(
             model_source = _generate_model_source(self.pysdk_model.model_data, False)
             create_optimization_job_args["ModelSource"] = model_source
 
-            optimization_config, quantization_override_env, compilation_override_env = (
-                _extract_optimization_config_and_env(quantization_config, compilation_config)
-            )
+            (
+                optimization_config,
+                quantization_override_env,
+                compilation_override_env,
+                sharding_override_env,
+            ) = _extract_optimization_config_and_env(
+                quantization_config, compilation_config, sharding_config
+            )
             create_optimization_job_args["OptimizationConfigs"] = [
                 {k: v} for k, v in optimization_config.items()
             ]
@@ -1388,6 +1466,7 @@ def _optimize_for_hf(
                 {
                     **(quantization_override_env or {}),
                     **(compilation_override_env or {}),
+                    **(sharding_override_env or {}),
                 }
             )
 
diff --git a/src/sagemaker/serve/utils/optimize_utils.py b/src/sagemaker/serve/utils/optimize_utils.py
index 14df6b3639..12676c1432 100644
--- a/src/sagemaker/serve/utils/optimize_utils.py
+++ b/src/sagemaker/serve/utils/optimize_utils.py
@@ -361,16 +361,19 @@ def _is_s3_uri(s3_uri: Optional[str]) -> bool:
 
 
 def _extract_optimization_config_and_env(
-    quantization_config: Optional[Dict] = None, compilation_config: Optional[Dict] = None
-) -> Optional[Tuple[Optional[Dict], Optional[Dict], Optional[Dict]]]:
+    quantization_config: Optional[Dict] = None,
+    compilation_config: Optional[Dict] = None,
+    sharding_config: Optional[Dict] = None,
+) -> Optional[Tuple[Optional[Dict], Optional[Dict], Optional[Dict], Optional[Dict]]]:
     """Extracts optimization config and environment variables.
 
     Args:
         quantization_config (Optional[Dict]): The quantization config.
compilation_config (Optional[Dict]): The compilation config. + sharding_config (Optional[Dict]): The sharding config. Returns: - Optional[Tuple[Optional[Dict], Optional[Dict], Optional[Dict]]]: + Optional[Tuple[Optional[Dict], Optional[Dict], Optional[Dict], Optional[Dict]]]: The optimization config and environment variables. """ optimization_config = {} @@ -380,6 +383,7 @@ def _extract_optimization_config_and_env( compilation_override_env = ( compilation_config.get("OverrideEnvironment") if compilation_config else None ) + sharding_override_env = sharding_config.get("OverrideEnvironment") if sharding_config else None if quantization_config is not None: optimization_config["ModelQuantizationConfig"] = quantization_config @@ -387,12 +391,19 @@ def _extract_optimization_config_and_env( if compilation_config is not None: optimization_config["ModelCompilationConfig"] = compilation_config + if sharding_config is not None: + optimization_config["ModelShardingConfig"] = sharding_config + # Return optimization config dict and environment variables if either is present if optimization_config: - return optimization_config, quantization_override_env, compilation_override_env - - return None, None, None + return ( + optimization_config, + quantization_override_env, + compilation_override_env, + sharding_override_env, + ) + return None, None, None, None def _custom_speculative_decoding( model: Model, diff --git a/src/sagemaker/serve/validations/optimization.py b/src/sagemaker/serve/validations/optimization.py new file mode 100644 index 0000000000..379b7caa92 --- /dev/null +++ b/src/sagemaker/serve/validations/optimization.py @@ -0,0 +1,225 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. 
This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Holds the validation logic used for the .optimize() function. INTERNAL only""" +from __future__ import absolute_import + +import textwrap +import logging +from typing import Any, Dict, Set, Optional +from enum import Enum +from pydantic import BaseModel + +logger = logging.getLogger(__name__) + + +class _OptimizationContainer(Enum): + """Optimization containers""" + + TRT = "TRT" + VLLM = "vLLM" + NEURON = "Neuron" + + +class _OptimizationCombination(BaseModel): + """Optimization ruleset data structure for comparing input to ruleset""" + + optimization_container: _OptimizationContainer = None + compilation: Set[Optional[bool]] + speculative_decoding: Set[Optional[bool]] + sharding: Set[Optional[bool]] + quantization_technique: Set[Optional[str]] + + def validate_against(self, optimization_combination, rule_set: _OptimizationContainer): + """Validator for optimization containers""" + + # check the validity of each individual field + if not optimization_combination.compilation.issubset(self.compilation): + raise ValueError("Compilation") + if not optimization_combination.quantization_technique.issubset( + self.quantization_technique + ): + copy_quantization_technique = optimization_combination.quantization_technique.copy() + raise ValueError(f"Quantization:{copy_quantization_technique.pop()}") + if not optimization_combination.speculative_decoding.issubset(self.speculative_decoding): + raise ValueError("Speculative Decoding") + if not optimization_combination.sharding.issubset(self.sharding): + raise ValueError("Sharding") + + # optimization technique combinations that need to be validated + if optimization_combination.compilation and optimization_combination.speculative_decoding: + is_compiled = 
optimization_combination.compilation.copy().pop() + is_speculative_decoding = optimization_combination.speculative_decoding.copy().pop() + if is_compiled and is_speculative_decoding: + raise ValueError("Compilation and Speculative Decoding together") + + if rule_set == _OptimizationContainer.TRT: + is_compiled = optimization_combination.compilation.copy().pop() + is_quantized = optimization_combination.quantization_technique.copy().pop() + if is_quantized and not is_compiled: + raise ValueError(f"Quantization:{is_quantized} must be provided with Compilation") + + +TRUTHY_SET = {None, True} +FALSY_SET = {None, False} +TRT_CONFIGURATION = { + "supported_instance_families": {"p4d", "p4de", "p5", "g5", "g6"}, + "optimization_combination": _OptimizationCombination( + optimization_container=_OptimizationContainer.TRT, + compilation=TRUTHY_SET, + quantization_technique={None, "awq", "fp8", "smoothquant"}, + speculative_decoding=FALSY_SET, + sharding=FALSY_SET, + ), +} +VLLM_CONFIGURATION = { + "supported_instance_families": {"p4d", "p4de", "p5", "g5", "g6"}, + "optimization_combination": _OptimizationCombination( + optimization_container=_OptimizationContainer.VLLM, + compilation=FALSY_SET, + quantization_technique={None, "awq", "fp8"}, + speculative_decoding=TRUTHY_SET, + sharding=TRUTHY_SET, + ), +} +NEURON_CONFIGURATION = { + "supported_instance_families": {"inf2", "trn1", "trn1n"}, + "optimization_combination": _OptimizationCombination( + optimization_container=_OptimizationContainer.NEURON, + compilation=TRUTHY_SET, + quantization_technique={None}, + speculative_decoding=FALSY_SET, + sharding=FALSY_SET, + ), +} + + +def _validate_optimization_configuration( + instance_type: str, + quantization_config: Dict[str, Any], + compilation_config: Dict[str, Any], + sharding_config: Dict[str, Any], + speculative_decoding_config: Dict[str, Any], +): + """Validate .optimize() input off of standard ruleset""" + + instance_family = None + if instance_type: + split_instance_type = 
instance_type.split(".")
+        if len(split_instance_type) == 3:
+            instance_family = split_instance_type[1]
+
+    if (
+        instance_family not in TRT_CONFIGURATION["supported_instance_families"]
+        and instance_family not in VLLM_CONFIGURATION["supported_instance_families"]
+        and instance_family not in NEURON_CONFIGURATION["supported_instance_families"]
+    ):
+        invalid_instance_type_msg = (
+            f"Optimizations that use {instance_type} instance type are "
+            "not currently supported on both GPU and Neuron instances"
+        )
+        raise ValueError(invalid_instance_type_msg)
+
+    quantization_technique = None
+    if (
+        quantization_config
+        and quantization_config.get("OverrideEnvironment")
+        and quantization_config.get("OverrideEnvironment").get("OPTION_QUANTIZE")
+    ):
+        quantization_technique = quantization_config.get("OverrideEnvironment").get(
+            "OPTION_QUANTIZE"
+        )
+
+    optimization_combination = _OptimizationCombination(
+        compilation={None if compilation_config is None else True},
+        speculative_decoding={None if speculative_decoding_config is None else True},
+        sharding={None if sharding_config is None else True},
+        quantization_technique={quantization_technique},
+    )
+
+    # Check the case where no optimization combination is provided
+    if (
+        optimization_combination.compilation == {None}
+        and optimization_combination.quantization_technique == {None}
+        and optimization_combination.speculative_decoding == {None}
+        and optimization_combination.sharding == {None}
+    ):
+        raise ValueError(
+            (
+                "Optimizations that provide no optimization configs "
+                "are currently not supported on both GPU and Neuron instances."
+ ) + ) + + # Validate based off of instance type + if instance_family in NEURON_CONFIGURATION["supported_instance_families"]: + try: + ( + NEURON_CONFIGURATION["optimization_combination"].validate_against( + optimization_combination, rule_set=_OptimizationContainer.NEURON + ) + ) + except ValueError as neuron_compare_error: + raise ValueError( + ( + f"Optimizations that use {neuron_compare_error} " + "are not supported on Neuron instances." + ) + ) + else: + if optimization_combination.compilation.copy().pop(): # Compilation is only enabled for TRT + try: + TRT_CONFIGURATION["optimization_combination"].validate_against( + optimization_combination, rule_set=_OptimizationContainer.TRT + ) + except ValueError as trt_compare_error: + raise ValueError( + ( + f"Optimizations that use Compilation and {trt_compare_error} " + "are not supported for GPU instances." + ) + ) + else: + try: + ( + VLLM_CONFIGURATION["optimization_combination"].validate_against( + optimization_combination, rule_set=_OptimizationContainer.VLLM + ) + ) + except ValueError as vllm_compare_error: + try: # try both VLLM and TRT to cover both rule sets + ( + TRT_CONFIGURATION["optimization_combination"].validate_against( + optimization_combination, rule_set=_OptimizationContainer.TRT + ) + ) + except ValueError as trt_compare_error: + if ( + str(trt_compare_error) + == "Quantization:smoothquant must be provided with Compilation" + ): + raise ValueError( + f"Optimizations that use {trt_compare_error} for GPU instances." + ) + if str(trt_compare_error) == str(vllm_compare_error): + raise ValueError( + ( + f"Optimizations that use {trt_compare_error} " + "are not supported for GPU instances." + ) + ) + joint_error_msg = f""" + Optimization cannot be performed for the following reasons: + - Optimizations that use {trt_compare_error} are not supported for GPU instances. + - Optimizations that use {vllm_compare_error} are not supported for GPU instances. 
+ """ + raise ValueError(textwrap.dedent(joint_error_msg)) diff --git a/tests/unit/sagemaker/model/test_model.py b/tests/unit/sagemaker/model/test_model.py index e43ad0ed0a..316df7420d 100644 --- a/tests/unit/sagemaker/model/test_model.py +++ b/tests/unit/sagemaker/model/test_model.py @@ -959,6 +959,56 @@ def test_all_framework_models_inference_component_based_endpoint_deploy_path( sagemaker_session.create_model.reset_mock() +@patch("sagemaker.utils.repack_model") +@patch("sagemaker.fw_utils.tar_and_upload_dir") +def test_sharded_model_force_inference_component_based_endpoint_deploy_path( + repack_model, tar_and_uload_dir, sagemaker_session +): + framework_model_classes_to_kwargs = { + HuggingFaceModel: { + "pytorch_version": "1.7.1", + "py_version": "py36", + "transformers_version": "4.6.1", + }, + } + + sagemaker_session.settings = SessionSettings(include_jumpstart_tags=False) + + source_dir = "s3://blah/blah/blah" + for framework_model_class, kwargs in framework_model_classes_to_kwargs.items(): + test_sharded_model = framework_model_class( + entry_point=ENTRY_POINT_INFERENCE, + role=ROLE, + sagemaker_session=sagemaker_session, + model_data=source_dir, + **kwargs, + ) + test_sharded_model._is_sharded_model = True + test_sharded_model.deploy( + instance_type="ml.m2.xlarge", + initial_instance_count=INSTANCE_COUNT, + endpoint_type=EndpointType.MODEL_BASED, + resources=ResourceRequirements( + requests={ + "num_accelerators": 1, + "memory": 8192, + "copies": 1, + }, + limits={}, + ), + ) + + # Verified inference component based endpoint and inference component creation + # path + sagemaker_session.endpoint_in_service_or_not.assert_called_once() + sagemaker_session.create_model.assert_called_once() + sagemaker_session.create_inference_component.assert_called_once() + + sagemaker_session.create_inference_component.reset_mock() + sagemaker_session.endpoint_in_service_or_not.reset_mock() + sagemaker_session.create_model.reset_mock() + + 
@patch("sagemaker.utils.repack_model") def test_repack_code_location_with_key_prefix(repack_model, sagemaker_session): diff --git a/tests/unit/sagemaker/serve/builder/test_model_builder.py b/tests/unit/sagemaker/serve/builder/test_model_builder.py index 2da09aece3..ab76d1bf99 100644 --- a/tests/unit/sagemaker/serve/builder/test_model_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_model_builder.py @@ -11,12 +11,14 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. from __future__ import absolute_import -from unittest.mock import MagicMock, patch, Mock, mock_open + +from unittest.mock import MagicMock, patch, Mock, mock_open, ANY import unittest from pathlib import Path from copy import deepcopy +from sagemaker.model import Model from sagemaker.serve import SchemaBuilder from sagemaker.serve.builder.model_builder import ModelBuilder from sagemaker.serve.mode.function_pointers import Mode @@ -25,6 +27,7 @@ from sagemaker.serve.utils.exceptions import TaskNotFoundException from sagemaker.serve.utils.predictors import TensorflowServingLocalPredictor from sagemaker.serve.utils.types import ModelServer +from sagemaker.serve.validations.optimization import _validate_optimization_configuration from tests.unit.sagemaker.serve.constants import MOCK_IMAGE_CONFIG, MOCK_VPC_CONFIG schema_builder = MagicMock() @@ -2383,11 +2386,11 @@ def test_optimize( builder.pysdk_model = pysdk_model job_name = "my-optimization-job" - instance_type = "ml.inf1.xlarge" + instance_type = "ml.g5.24xlarge" output_path = "s3://my-bucket/output" quantization_config = { "Image": "quantization-image-uri", - "OverrideEnvironment": {"ENV_VAR": "value"}, + "OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}, } env_vars = {"Var1": "value", "Var2": "value"} kms_key = "arn:aws:kms:us-west-2:123456789012:key/my-key-id" @@ -2425,7 +2428,7 @@ def test_optimize( mock_send_telemetry.assert_called_once() 
mock_sagemaker_session.sagemaker_client.create_optimization_job.assert_called_once_with( OptimizationJobName="my-optimization-job", - DeploymentInstanceType="ml.inf1.xlarge", + DeploymentInstanceType="ml.g5.24xlarge", RoleArn="arn:aws:iam::123456789012:role/SageMakerRole", OptimizationEnvironment={"Var1": "value", "Var2": "value"}, ModelSource={"S3": {"S3Uri": "s3://uri"}}, @@ -2433,7 +2436,7 @@ def test_optimize( { "ModelQuantizationConfig": { "Image": "quantization-image-uri", - "OverrideEnvironment": {"ENV_VAR": "value"}, + "OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}, } } ], @@ -2646,7 +2649,8 @@ def test_optimize_local_mode(self, mock_get_serve_setting): ValueError, "Model optimization is only supported in Sagemaker Endpoint Mode.", lambda: model_builder.optimize( - quantization_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}} + instance_type="ml.g5.24xlarge", + quantization_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}}, ), ) @@ -2721,6 +2725,42 @@ def test_optimize_for_hf_with_both_quantization_and_compilation( }, ) + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + def test_optimize_exclusive_sharding(self, mock_get_serve_setting): + mock_sagemaker_session = Mock() + model_builder = ModelBuilder( + model="meta-textgeneration-llama-3-70b", + sagemaker_session=mock_sagemaker_session, + ) + + self.assertRaisesRegex( + ValueError, + "Optimizations that use Compilation and Sharding are not supported for GPU instances.", + lambda: model_builder.optimize( + instance_type="ml.g5.24xlarge", + quantization_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}}, + compilation_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}}, + sharding_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}}, + ), + ) + + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + def test_optimize_exclusive_sharding_args(self, mock_get_serve_setting): + mock_sagemaker_session = Mock() + model_builder = ModelBuilder( 
+ model="meta-textgeneration-llama-3-70b", + sagemaker_session=mock_sagemaker_session, + ) + + self.assertRaisesRegex( + ValueError, + "OPTION_TENSOR_PARALLEL_DEGREE is a required environment variable with sharding config.", + lambda: model_builder.optimize( + instance_type="ml.g5.24xlarge", + sharding_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}}, + ), + ) + @patch.object(ModelBuilder, "_prepare_for_mode") @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) def test_optimize_for_hf_with_custom_s3_path( @@ -2946,3 +2986,655 @@ def test_optimize_with_gpu_instance_and_compilation_with_speculative_decoding( output_path="s3://bucket/code/", ), ) + + +class TestModelBuilderOptimizationSharding(unittest.TestCase): + @patch.object(ModelBuilder, "_prepare_for_mode") + @patch.object(ModelBuilder, "_build_for_djl") + @patch.object(ModelBuilder, "_is_jumpstart_model_id", return_value=False) + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + @patch("sagemaker.serve.utils.telemetry_logger._send_telemetry") + def test_optimize_sharding_with_env_vars( + self, + mock_send_telemetry, + mock_get_serve_setting, + mock_is_jumpstart_model_id, + mock_build_for_djl, + mock_prepare_for_mode, + ): + mock_sagemaker_session = Mock() + + mock_settings = Mock() + mock_settings.telemetry_opt_out = False + mock_get_serve_setting.return_value = mock_settings + + pysdk_model = Mock() + pysdk_model.env = {"key": "val"} + pysdk_model.add_tags.side_effect = lambda *arg, **kwargs: None + + mock_build_for_djl.side_effect = lambda **kwargs: pysdk_model + mock_prepare_for_mode.side_effect = lambda *args, **kwargs: ( + { + "S3DataSource": { + "S3Uri": "s3://uri", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + {"key": "val"}, + ) + + builder = ModelBuilder( + schema_builder=SchemaBuilder( + sample_input={"inputs": "Hello", "parameters": {}}, + sample_output=[{"generated_text": "Hello"}], + ), + model="meta-llama/Meta-Llama-3-8B", + 
sagemaker_session=mock_sagemaker_session, + env_vars={"HF_TOKEN": "token"}, + model_metadata={"CUSTOM_MODEL_PATH": "/tmp/modelbuilders/code"}, + ) + builder.pysdk_model = pysdk_model + + job_name = "my-optimization-job" + instance_type = "ml.g5.24xlarge" + output_path = "s3://my-bucket/output" + sharding_config = {"key": "value"} + env_vars = {"OPTION_TENSOR_PARALLEL_DEGREE": "1"} + kms_key = "arn:aws:kms:us-west-2:123456789012:key/my-key-id" + max_runtime_in_sec = 3600 + tags = [ + {"Key": "Project", "Value": "my-project"}, + {"Key": "Environment", "Value": "production"}, + ] + vpc_config = { + "SecurityGroupIds": ["sg-01234567890abcdef", "sg-fedcba9876543210"], + "Subnets": ["subnet-01234567", "subnet-89abcdef"], + } + + mock_sagemaker_session.wait_for_optimization_job.side_effect = lambda *args, **kwargs: { + "OptimizationJobArn": "arn:aws:sagemaker:us-west-2:123456789012:optimization-job/my-optimization-job", + "OptimizationJobName": "my-optimization-job", + } + + # With override + builder.optimize( + instance_type=instance_type, + output_path=output_path, + role_arn=mock_role_arn, + job_name=job_name, + sharding_config=sharding_config, + env_vars=env_vars, + kms_key=kms_key, + max_runtime_in_sec=max_runtime_in_sec, + tags=tags, + vpc_config=vpc_config, + ) + + self.assertEqual(builder.env_vars["HUGGING_FACE_HUB_TOKEN"], "token") + self.assertEqual(builder.model_server, ModelServer.DJL_SERVING) + + mock_send_telemetry.assert_called_once() + mock_sagemaker_session.sagemaker_client.create_optimization_job.assert_called_once_with( + OptimizationJobName="my-optimization-job", + DeploymentInstanceType="ml.g5.24xlarge", + RoleArn="arn:aws:iam::123456789012:role/SageMakerRole", + OptimizationEnvironment={"OPTION_TENSOR_PARALLEL_DEGREE": "1"}, + ModelSource={"S3": {"S3Uri": "s3://uri"}}, + OptimizationConfigs=[{"ModelShardingConfig": {"key": "value"}}], + OutputConfig={ + "S3OutputLocation": "s3://my-bucket/output", + "KmsKeyId": 
"arn:aws:kms:us-west-2:123456789012:key/my-key-id", + }, + StoppingCondition={"MaxRuntimeInSeconds": 3600}, + Tags=[ + {"Key": "Project", "Value": "my-project"}, + {"Key": "Environment", "Value": "production"}, + ], + VpcConfig={ + "SecurityGroupIds": ["sg-01234567890abcdef", "sg-fedcba9876543210"], + "Subnets": ["subnet-01234567", "subnet-89abcdef"], + }, + ) + + @patch.object(ModelBuilder, "_prepare_for_mode") + @patch.object(ModelBuilder, "_build_for_djl") + @patch.object(ModelBuilder, "_is_jumpstart_model_id", return_value=False) + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + @patch("sagemaker.serve.utils.telemetry_logger._send_telemetry") + def test_optimize_sharding_with_override_and_env_var( + self, + mock_send_telemetry, + mock_get_serve_setting, + mock_is_jumpstart_model_id, + mock_build_for_djl, + mock_prepare_for_mode, + ): + mock_sagemaker_session = Mock() + + mock_settings = Mock() + mock_settings.telemetry_opt_out = False + mock_get_serve_setting.return_value = mock_settings + + pysdk_model = Mock() + pysdk_model.env = {"key": "val"} + pysdk_model.add_tags.side_effect = lambda *arg, **kwargs: None + + mock_build_for_djl.side_effect = lambda **kwargs: pysdk_model + mock_prepare_for_mode.side_effect = lambda *args, **kwargs: ( + { + "S3DataSource": { + "S3Uri": "s3://uri", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + {"key": "val"}, + ) + + builder = ModelBuilder( + schema_builder=SchemaBuilder( + sample_input={"inputs": "Hello", "parameters": {}}, + sample_output=[{"generated_text": "Hello"}], + ), + model="meta-llama/Meta-Llama-3-8B", + sagemaker_session=mock_sagemaker_session, + env_vars={"HF_TOKEN": "token"}, + model_metadata={"CUSTOM_MODEL_PATH": "/tmp/modelbuilders/code"}, + ) + builder.pysdk_model = pysdk_model + + job_name = "my-optimization-job" + instance_type = "ml.g5.24xlarge" + output_path = "s3://my-bucket/output" + sharding_config = {"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": 
"1"}} + env_vars = {"OPTION_TENSOR_PARALLEL_DEGREE": "1"} + kms_key = "arn:aws:kms:us-west-2:123456789012:key/my-key-id" + max_runtime_in_sec = 3600 + tags = [ + {"Key": "Project", "Value": "my-project"}, + {"Key": "Environment", "Value": "production"}, + ] + vpc_config = { + "SecurityGroupIds": ["sg-01234567890abcdef", "sg-fedcba9876543210"], + "Subnets": ["subnet-01234567", "subnet-89abcdef"], + } + + mock_sagemaker_session.wait_for_optimization_job.side_effect = lambda *args, **kwargs: { + "OptimizationJobArn": "arn:aws:sagemaker:us-west-2:123456789012:optimization-job/my-optimization-job", + "OptimizationJobName": "my-optimization-job", + } + + # With override + builder.optimize( + instance_type=instance_type, + output_path=output_path, + role_arn=mock_role_arn, + job_name=job_name, + sharding_config=sharding_config, + env_vars=env_vars, + kms_key=kms_key, + max_runtime_in_sec=max_runtime_in_sec, + tags=tags, + vpc_config=vpc_config, + ) + + self.assertEqual(builder.env_vars["HUGGING_FACE_HUB_TOKEN"], "token") + self.assertEqual(builder.model_server, ModelServer.DJL_SERVING) + + mock_send_telemetry.assert_called_once() + mock_sagemaker_session.sagemaker_client.create_optimization_job.assert_called_once_with( + OptimizationJobName="my-optimization-job", + DeploymentInstanceType="ml.g5.24xlarge", + RoleArn="arn:aws:iam::123456789012:role/SageMakerRole", + OptimizationEnvironment={"OPTION_TENSOR_PARALLEL_DEGREE": "1"}, + ModelSource={"S3": {"S3Uri": "s3://uri"}}, + OptimizationConfigs=[ + { + "ModelShardingConfig": { + "OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "1"} + } + } + ], + OutputConfig={ + "S3OutputLocation": "s3://my-bucket/output", + "KmsKeyId": "arn:aws:kms:us-west-2:123456789012:key/my-key-id", + }, + StoppingCondition={"MaxRuntimeInSeconds": 3600}, + Tags=[ + {"Key": "Project", "Value": "my-project"}, + {"Key": "Environment", "Value": "production"}, + ], + VpcConfig={ + "SecurityGroupIds": ["sg-01234567890abcdef", "sg-fedcba9876543210"], 
+ "Subnets": ["subnet-01234567", "subnet-89abcdef"], + }, + ) + + @patch.object(ModelBuilder, "_prepare_for_mode") + @patch.object(ModelBuilder, "_build_for_djl") + @patch.object(ModelBuilder, "_is_jumpstart_model_id", return_value=False) + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + @patch("sagemaker.serve.utils.telemetry_logger._send_telemetry") + def test_optimize_sharding_with_override( + self, + mock_send_telemetry, + mock_get_serve_setting, + mock_is_jumpstart_model_id, + mock_build_for_djl, + mock_prepare_for_mode, + ): + mock_sagemaker_session = Mock() + + mock_settings = Mock() + mock_settings.telemetry_opt_out = False + mock_get_serve_setting.return_value = mock_settings + + pysdk_model = Mock() + pysdk_model.env = {"key": "val"} + pysdk_model.add_tags.side_effect = lambda *arg, **kwargs: None + + mock_build_for_djl.side_effect = lambda **kwargs: pysdk_model + mock_prepare_for_mode.side_effect = lambda *args, **kwargs: ( + { + "S3DataSource": { + "S3Uri": "s3://uri", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + {"key": "val"}, + ) + + builder = ModelBuilder( + schema_builder=SchemaBuilder( + sample_input={"inputs": "Hello", "parameters": {}}, + sample_output=[{"generated_text": "Hello"}], + ), + model="meta-llama/Meta-Llama-3-8B", + sagemaker_session=mock_sagemaker_session, + env_vars={"HF_TOKEN": "token"}, + model_metadata={"CUSTOM_MODEL_PATH": "/tmp/modelbuilders/code"}, + ) + builder.pysdk_model = pysdk_model + + job_name = "my-optimization-job" + instance_type = "ml.g5.24xlarge" + output_path = "s3://my-bucket/output" + sharding_config = {"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "1"}} + env_vars = {"Var1": "value", "Var2": "value"} + kms_key = "arn:aws:kms:us-west-2:123456789012:key/my-key-id" + max_runtime_in_sec = 3600 + tags = [ + {"Key": "Project", "Value": "my-project"}, + {"Key": "Environment", "Value": "production"}, + ] + vpc_config = { + "SecurityGroupIds": ["sg-01234567890abcdef", 
"sg-fedcba9876543210"], + "Subnets": ["subnet-01234567", "subnet-89abcdef"], + } + + mock_sagemaker_session.wait_for_optimization_job.side_effect = lambda *args, **kwargs: { + "OptimizationJobArn": "arn:aws:sagemaker:us-west-2:123456789012:optimization-job/my-optimization-job", + "OptimizationJobName": "my-optimization-job", + } + + # With override + builder.optimize( + instance_type=instance_type, + output_path=output_path, + role_arn=mock_role_arn, + job_name=job_name, + sharding_config=sharding_config, + env_vars=env_vars, + kms_key=kms_key, + max_runtime_in_sec=max_runtime_in_sec, + tags=tags, + vpc_config=vpc_config, + ) + + self.assertEqual(builder.env_vars["HUGGING_FACE_HUB_TOKEN"], "token") + self.assertEqual(builder.model_server, ModelServer.DJL_SERVING) + + mock_send_telemetry.assert_called_once() + mock_sagemaker_session.sagemaker_client.create_optimization_job.assert_called_once_with( + OptimizationJobName="my-optimization-job", + DeploymentInstanceType="ml.g5.24xlarge", + RoleArn="arn:aws:iam::123456789012:role/SageMakerRole", + OptimizationEnvironment={"Var1": "value", "Var2": "value"}, + ModelSource={"S3": {"S3Uri": "s3://uri"}}, + OptimizationConfigs=[ + { + "ModelShardingConfig": { + "OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "1"} + } + } + ], + OutputConfig={ + "S3OutputLocation": "s3://my-bucket/output", + "KmsKeyId": "arn:aws:kms:us-west-2:123456789012:key/my-key-id", + }, + StoppingCondition={"MaxRuntimeInSeconds": 3600}, + Tags=[ + {"Key": "Project", "Value": "my-project"}, + {"Key": "Environment", "Value": "production"}, + ], + VpcConfig={ + "SecurityGroupIds": ["sg-01234567890abcdef", "sg-fedcba9876543210"], + "Subnets": ["subnet-01234567", "subnet-89abcdef"], + }, + ) + + # squeeze in some validations + with self.assertRaises(ValueError): + builder.enable_network_isolation = True + builder.optimize(sharding_config={}) + + @patch.object(ModelBuilder, "_prepare_for_mode") + @patch.object(ModelBuilder, "_build_for_jumpstart") + 
@patch.object(ModelBuilder, "_is_jumpstart_model_id", return_value=True) + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + @patch("sagemaker.serve.utils.telemetry_logger._send_telemetry") + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_gated_model", return_value=False + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._find_compatible_deployment_config", + return_value=Mock(), + ) + def test_optimize_sharding_with_override_for_js( + self, + mock_find_compatible_deployment_config, + mock_is_gated_model, + mock_send_telemetry, + mock_get_serve_setting, + mock_is_jumpstart_model_id, + mock_build_for_jumpstart, + mock_prepare_for_mode, + ): + mock_sagemaker_session = Mock() + + mock_settings = Mock() + mock_settings.telemetry_opt_out = False + mock_get_serve_setting.return_value = mock_settings + + pysdk_model = Mock() + pysdk_model.env = {"key": "val"} + pysdk_model._enable_network_isolation = True + pysdk_model.add_tags.side_effect = lambda *arg, **kwargs: None + + mock_build_for_jumpstart.side_effect = lambda **kwargs: pysdk_model + mock_prepare_for_mode.side_effect = lambda *args, **kwargs: ( + { + "S3DataSource": { + "S3Uri": "s3://uri", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + {"key": "val"}, + ) + + builder = ModelBuilder( + schema_builder=SchemaBuilder( + sample_input={"inputs": "Hello", "parameters": {}}, + sample_output=[{"generated_text": "Hello"}], + ), + model="meta-llama/Meta-Llama-3-8B", + sagemaker_session=mock_sagemaker_session, + env_vars={"HF_TOKEN": "token"}, + model_metadata={"CUSTOM_MODEL_PATH": "/tmp/modelbuilders/code"}, + ) + builder.pysdk_model = pysdk_model + + job_name = "my-optimization-job" + instance_type = "ml.g5.24xlarge" + output_path = "s3://my-bucket/output" + sharding_config = {"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "1"}} + env_vars = {"Var1": "value", "Var2": "value"} + kms_key = "arn:aws:kms:us-west-2:123456789012:key/my-key-id" 
+ max_runtime_in_sec = 3600 + tags = [ + {"Key": "Project", "Value": "my-project"}, + {"Key": "Environment", "Value": "production"}, + ] + vpc_config = { + "SecurityGroupIds": ["sg-01234567890abcdef", "sg-fedcba9876543210"], + "Subnets": ["subnet-01234567", "subnet-89abcdef"], + } + + mock_sagemaker_session.wait_for_optimization_job.side_effect = lambda *args, **kwargs: { + "OptimizationJobArn": "arn:aws:sagemaker:us-west-2:123456789012:optimization-job/my-optimization-job", + "OptimizationJobName": "my-optimization-job", + } + + # With override + model = builder.optimize( + instance_type=instance_type, + output_path=output_path, + role_arn=mock_role_arn, + job_name=job_name, + sharding_config=sharding_config, + env_vars=env_vars, + kms_key=kms_key, + max_runtime_in_sec=max_runtime_in_sec, + tags=tags, + vpc_config=vpc_config, + ) + + self.assertEqual(builder.env_vars["HUGGING_FACE_HUB_TOKEN"], "token") + + mock_send_telemetry.assert_called_once() + mock_sagemaker_session.sagemaker_client.create_optimization_job.assert_called_once_with( + OptimizationJobName="my-optimization-job", + ModelSource={"S3": {"S3Uri": ANY}}, + DeploymentInstanceType="ml.g5.24xlarge", + OptimizationConfigs=[ + { + "ModelShardingConfig": { + "OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "1"} + } + } + ], + OutputConfig={ + "S3OutputLocation": "s3://my-bucket/output", + "KmsKeyId": "arn:aws:kms:us-west-2:123456789012:key/my-key-id", + }, + RoleArn="arn:aws:iam::123456789012:role/SageMakerRole", + OptimizationEnvironment={ + "key": "val", + "Var1": "value", + "Var2": "value", + "OPTION_TENSOR_PARALLEL_DEGREE": "1", + }, + StoppingCondition={"MaxRuntimeInSeconds": 3600}, + Tags=[ + {"Key": "Project", "Value": "my-project"}, + {"Key": "Environment", "Value": "production"}, + ], + VpcConfig={ + "SecurityGroupIds": ["sg-01234567890abcdef", "sg-fedcba9876543210"], + "Subnets": ["subnet-01234567", "subnet-89abcdef"], + }, + ) + + assert not model._enable_network_isolation + + def 
test_model_sharding_with_eni_fails(self): + test_model = Model(role="mock role") + test_model._is_sharded_model = True + test_model._enable_network_isolation = True + self.assertRaisesRegex( + ValueError, + ( + "EnableNetworkIsolation cannot be set to True since " + "SageMaker Fast Model Loading of model requires network access." + ), + lambda: test_model.deploy(initial_instance_count=1, instance_type="ml.g5.24xlarge"), + ) + + +class TestModelBuilderOptimizeValidations(unittest.TestCase): + + def test_corner_cases_throw_errors(self): + self.assertRaisesRegex( + ValueError, + "Optimizations that uses None instance type are not currently supported", + lambda: _validate_optimization_configuration( + sharding_config={"key": "value"}, + instance_type=None, + quantization_config=None, + speculative_decoding_config=None, + compilation_config=None, + ), + ) + + self.assertRaisesRegex( + ValueError, + ( + "Optimizations that provide no optimization configs " + "are currently not support on both GPU and Neuron instances." 
+ ), + lambda: _validate_optimization_configuration( + instance_type="ml.g5.24xlarge", + quantization_config=None, + speculative_decoding_config=None, + compilation_config=None, + sharding_config=None, + ), + ) + + def test_trt_and_vllm_configurations_throw_errors_for_rule_set(self): + # Quantization:smoothquant without compilation + self.assertRaisesRegex( + ValueError, + "Optimizations that use Quantization:smoothquant must be provided with Compilation for GPU instances.", + lambda: _validate_optimization_configuration( + instance_type="ml.g5.24xlarge", + quantization_config={ + "OverrideEnvironment": {"OPTION_QUANTIZE": "smoothquant"}, + }, + sharding_config=None, + speculative_decoding_config=None, + compilation_config=None, + ), + ) + + # Invalid quantization technique + self.assertRaisesRegex( + ValueError, + "Optimizations that use Quantization:test are not supported for GPU instances.", + lambda: _validate_optimization_configuration( + instance_type="ml.g5.24xlarge", + quantization_config={ + "OverrideEnvironment": {"OPTION_QUANTIZE": "test"}, + }, + sharding_config=None, + speculative_decoding_config=None, + compilation_config=None, + ), + ) + + def test_neuron_configurations_throw_errors_for_rule_set(self): + self.assertRaisesRegex( + ValueError, + "Optimizations that use Speculative Decoding are not supported on Neuron instances.", + lambda: _validate_optimization_configuration( + instance_type="ml.inf2.xlarge", + quantization_config=None, + speculative_decoding_config={"key": "value"}, + compilation_config=None, + sharding_config=None, + ), + ) + + self.assertRaisesRegex( + ValueError, + "Optimizations that use Sharding are not supported on Neuron instances.", + lambda: _validate_optimization_configuration( + instance_type="ml.inf2.xlarge", + quantization_config=None, + speculative_decoding_config=None, + compilation_config=None, + sharding_config={"key": "value"}, + ), + ) + + def test_trt_configurations_rule_set(self): + # Can be compiled with 
quantization + _validate_optimization_configuration( + instance_type="ml.g5.24xlarge", + quantization_config={ + "OverrideEnvironment": {"OPTION_QUANTIZE": "smoothquant"}, + }, + sharding_config=None, + speculative_decoding_config=None, + compilation_config={"key": "value"}, + ), + + # Can be just compiled + _validate_optimization_configuration( + instance_type="ml.g5.24xlarge", + quantization_config=None, + sharding_config=None, + speculative_decoding_config=None, + compilation_config={"key": "value"}, + ) + + # Can be just compiled with empty dict + _validate_optimization_configuration( + instance_type="ml.g5.24xlarge", + quantization_config=None, + sharding_config=None, + speculative_decoding_config=None, + compilation_config={}, + ) + + def test_vllm_configurations_rule_set(self): + # Can use speculative decoding + _validate_optimization_configuration( + instance_type="ml.g5.24xlarge", + quantization_config=None, + sharding_config=None, + speculative_decoding_config={"key": "value"}, + compilation_config=None, + ) + + # Can be quantized + _validate_optimization_configuration( + instance_type="ml.g5.24xlarge", + quantization_config={ + "OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}, + }, + sharding_config=None, + speculative_decoding_config=None, + compilation_config=None, + ) + + # Can be sharded + _validate_optimization_configuration( + instance_type="ml.g5.24xlarge", + quantization_config=None, + sharding_config={"key": "value"}, + speculative_decoding_config=None, + compilation_config=None, + ) + + def test_neuron_configurations_rule_set(self): + # Can be compiled + _validate_optimization_configuration( + instance_type="ml.inf2.xlarge", + quantization_config=None, + sharding_config=None, + speculative_decoding_config=None, + compilation_config={"key": "value"}, + ) + + # Can be compiled with empty dict + _validate_optimization_configuration( + instance_type="ml.inf2.xlarge", + quantization_config=None, + sharding_config=None, + 
speculative_decoding_config=None, + compilation_config={}, + ) diff --git a/tests/unit/sagemaker/serve/utils/test_optimize_utils.py b/tests/unit/sagemaker/serve/utils/test_optimize_utils.py index 7cf0406f42..2dbd415eee 100644 --- a/tests/unit/sagemaker/serve/utils/test_optimize_utils.py +++ b/tests/unit/sagemaker/serve/utils/test_optimize_utils.py @@ -284,7 +284,7 @@ def test_is_draft_model_gated(draft_model_config, expected): @pytest.mark.parametrize( - "quantization_config, compilation_config, expected_config, expected_quant_env, expected_compilation_env", + "quantization_config, compilation_config, sharding_config, expected_config, expected_quant_env, expected_compilation_env, expected_sharding_env", [ ( None, @@ -293,6 +293,7 @@ def test_is_draft_model_gated(draft_model_config, expected): "OPTION_TENSOR_PARALLEL_DEGREE": "2", } }, + None, { "ModelCompilationConfig": { "OverrideEnvironment": { @@ -304,6 +305,7 @@ def test_is_draft_model_gated(draft_model_config, expected): { "OPTION_TENSOR_PARALLEL_DEGREE": "2", }, + None, ), ( { @@ -312,6 +314,7 @@ def test_is_draft_model_gated(draft_model_config, expected): } }, None, + None, { "ModelQuantizationConfig": { "OverrideEnvironment": { @@ -323,21 +326,48 @@ def test_is_draft_model_gated(draft_model_config, expected): "OPTION_TENSOR_PARALLEL_DEGREE": "2", }, None, + None, + ), + ( + None, + None, + { + "OverrideEnvironment": { + "OPTION_TENSOR_PARALLEL_DEGREE": "2", + } + }, + { + "ModelShardingConfig": { + "OverrideEnvironment": { + "OPTION_TENSOR_PARALLEL_DEGREE": "2", + } + }, + }, + None, + None, + { + "OPTION_TENSOR_PARALLEL_DEGREE": "2", + }, ), - (None, None, None, None, None), + (None, None, None, None, None, None, None), ], ) def test_extract_optimization_config_and_env( quantization_config, compilation_config, + sharding_config, expected_config, expected_quant_env, expected_compilation_env, + expected_sharding_env, ): - assert _extract_optimization_config_and_env(quantization_config, compilation_config) == 
( + assert _extract_optimization_config_and_env( + quantization_config, compilation_config, sharding_config + ) == ( expected_config, expected_quant_env, expected_compilation_env, + expected_sharding_env, ) From 0707798c3d023836346adf630de4ca243b47f6c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gary=20Wang=20=F0=9F=98=A4?= Date: Tue, 19 Nov 2024 06:28:23 +0000 Subject: [PATCH 2/2] fix rebase issues --- src/sagemaker/serve/builder/model_builder.py | 12 +++------ src/sagemaker/serve/utils/optimize_utils.py | 1 + .../serve/validations/optimization.py | 4 +++ .../serve/builder/test_model_builder.py | 27 ++++++++++++++++++- .../serve/utils/test_optimize_utils.py | 5 +++- 5 files changed, 39 insertions(+), 10 deletions(-) diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index cfb50584e0..6a3b093ac5 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -1248,6 +1248,7 @@ def _model_builder_optimize_wrapper( # TODO: ideally these dictionaries need to be sagemaker_core shapes # TODO: for organization, abstract all validation behind this fn _validate_optimization_configuration( + is_jumpstart=self._is_jumpstart_model_id(), instance_type=instance_type, quantization_config=quantization_config, compilation_config=compilation_config, @@ -1264,13 +1265,6 @@ def _model_builder_optimize_wrapper( if self.mode != Mode.SAGEMAKER_ENDPOINT: raise ValueError("Model optimization is only supported in Sagemaker Endpoint Mode.") - if sharding_config and ( - quantization_config or compilation_config or speculative_decoding_config - ): - raise ValueError( - "Sharding config is mutually exclusive and cannot be combined with any other optimization." 
- ) - if sharding_config and ( quantization_config or compilation_config or speculative_decoding_config ): @@ -1456,7 +1450,9 @@ def _optimize_for_hf( quantization_override_env, compilation_override_env, sharding_override_env, - ) = _extract_optimization_config_and_env(quantization_config, compilation_config) + ) = _extract_optimization_config_and_env( + quantization_config, compilation_config, sharding_config + ) create_optimization_job_args["OptimizationConfigs"] = [ {k: v} for k, v in optimization_config.items() ] diff --git a/src/sagemaker/serve/utils/optimize_utils.py b/src/sagemaker/serve/utils/optimize_utils.py index 12676c1432..68ed1e846d 100644 --- a/src/sagemaker/serve/utils/optimize_utils.py +++ b/src/sagemaker/serve/utils/optimize_utils.py @@ -405,6 +405,7 @@ def _extract_optimization_config_and_env( return None, None, None, None + def _custom_speculative_decoding( model: Model, speculative_decoding_config: Optional[Dict], diff --git a/src/sagemaker/serve/validations/optimization.py b/src/sagemaker/serve/validations/optimization.py index 379b7caa92..58ef167039 100644 --- a/src/sagemaker/serve/validations/optimization.py +++ b/src/sagemaker/serve/validations/optimization.py @@ -104,6 +104,7 @@ def validate_against(self, optimization_combination, rule_set: _OptimizationCont def _validate_optimization_configuration( + is_jumpstart: bool, instance_type: str, quantization_config: Dict[str, Any], compilation_config: Dict[str, Any], @@ -153,6 +154,9 @@ def _validate_optimization_configuration( and optimization_combination.speculative_decoding == {None} and optimization_combination.sharding == {None} ): + # JumpStart has defaults for Inf/Trn instances + if is_jumpstart and instance_family in NEURON_CONFIGURATION["supported_instance_families"]: + return raise ValueError( ( "Optimizations that provide no optimization configs " diff --git a/tests/unit/sagemaker/serve/builder/test_model_builder.py b/tests/unit/sagemaker/serve/builder/test_model_builder.py index 
ab76d1bf99..4e34c5f864 100644 --- a/tests/unit/sagemaker/serve/builder/test_model_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_model_builder.py @@ -2927,6 +2927,7 @@ def test_optimize_with_gpu_instance_and_llama_3_1_and_compilation( "Compilation is not supported for Llama-3.1 with a GPU instance.", lambda: model_builder.optimize( job_name="job_name-123", + instance_type="ml.g5.24xlarge", compilation_config={"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "2"}}, output_path="s3://bucket/code/", ), @@ -2975,9 +2976,10 @@ def test_optimize_with_gpu_instance_and_compilation_with_speculative_decoding( self.assertRaisesRegex( ValueError, - "Compilation is not supported with speculative decoding with a GPU instance.", + "Optimizations that use Compilation and Speculative Decoding are not supported for GPU instances.", lambda: model_builder.optimize( job_name="job_name-123", + instance_type="ml.g5.24xlarge", speculative_decoding_config={ "ModelProvider": "custom", "ModelSource": "s3://data-source", @@ -3481,6 +3483,7 @@ def test_corner_cases_throw_errors(self): ValueError, "Optimizations that uses None instance type are not currently supported", lambda: _validate_optimization_configuration( + is_jumpstart=False, sharding_config={"key": "value"}, instance_type=None, quantization_config=None, @@ -3496,6 +3499,7 @@ def test_corner_cases_throw_errors(self): "are currently not support on both GPU and Neuron instances." 
), lambda: _validate_optimization_configuration( + is_jumpstart=False, instance_type="ml.g5.24xlarge", quantization_config=None, speculative_decoding_config=None, @@ -3504,12 +3508,22 @@ def test_corner_cases_throw_errors(self): ), ) + _validate_optimization_configuration( + is_jumpstart=True, + instance_type="ml.inf2.xlarge", + quantization_config=None, + speculative_decoding_config=None, + compilation_config=None, + sharding_config=None, + ) + def test_trt_and_vllm_configurations_throw_errors_for_rule_set(self): # Quantization:smoothquant without compilation self.assertRaisesRegex( ValueError, "Optimizations that use Quantization:smoothquant must be provided with Compilation for GPU instances.", lambda: _validate_optimization_configuration( + is_jumpstart=False, instance_type="ml.g5.24xlarge", quantization_config={ "OverrideEnvironment": {"OPTION_QUANTIZE": "smoothquant"}, @@ -3525,6 +3539,7 @@ def test_trt_and_vllm_configurations_throw_errors_for_rule_set(self): ValueError, "Optimizations that use Quantization:test are not supported for GPU instances.", lambda: _validate_optimization_configuration( + is_jumpstart=False, instance_type="ml.g5.24xlarge", quantization_config={ "OverrideEnvironment": {"OPTION_QUANTIZE": "test"}, @@ -3540,6 +3555,7 @@ def test_neuron_configurations_throw_errors_for_rule_set(self): ValueError, "Optimizations that use Speculative Decoding are not supported on Neuron instances.", lambda: _validate_optimization_configuration( + is_jumpstart=False, instance_type="ml.inf2.xlarge", quantization_config=None, speculative_decoding_config={"key": "value"}, @@ -3552,6 +3568,7 @@ def test_neuron_configurations_throw_errors_for_rule_set(self): ValueError, "Optimizations that use Sharding are not supported on Neuron instances.", lambda: _validate_optimization_configuration( + is_jumpstart=False, instance_type="ml.inf2.xlarge", quantization_config=None, speculative_decoding_config=None, @@ -3563,6 +3580,7 @@ def 
test_neuron_configurations_throw_errors_for_rule_set(self): def test_trt_configurations_rule_set(self): # Can be compiled with quantization _validate_optimization_configuration( + is_jumpstart=False, instance_type="ml.g5.24xlarge", quantization_config={ "OverrideEnvironment": {"OPTION_QUANTIZE": "smoothquant"}, @@ -3574,6 +3592,7 @@ def test_trt_configurations_rule_set(self): # Can be just compiled _validate_optimization_configuration( + is_jumpstart=False, instance_type="ml.g5.24xlarge", quantization_config=None, sharding_config=None, @@ -3583,6 +3602,7 @@ def test_trt_configurations_rule_set(self): # Can be just compiled with empty dict _validate_optimization_configuration( + is_jumpstart=False, instance_type="ml.g5.24xlarge", quantization_config=None, sharding_config=None, @@ -3593,6 +3613,7 @@ def test_trt_configurations_rule_set(self): def test_vllm_configurations_rule_set(self): # Can use speculative decoding _validate_optimization_configuration( + is_jumpstart=False, instance_type="ml.g5.24xlarge", quantization_config=None, sharding_config=None, @@ -3602,6 +3623,7 @@ def test_vllm_configurations_rule_set(self): # Can be quantized _validate_optimization_configuration( + is_jumpstart=False, instance_type="ml.g5.24xlarge", quantization_config={ "OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}, @@ -3613,6 +3635,7 @@ def test_vllm_configurations_rule_set(self): # Can be sharded _validate_optimization_configuration( + is_jumpstart=False, instance_type="ml.g5.24xlarge", quantization_config=None, sharding_config={"key": "value"}, @@ -3623,6 +3646,7 @@ def test_vllm_configurations_rule_set(self): def test_neuron_configurations_rule_set(self): # Can be compiled _validate_optimization_configuration( + is_jumpstart=False, instance_type="ml.inf2.xlarge", quantization_config=None, sharding_config=None, @@ -3632,6 +3656,7 @@ def test_neuron_configurations_rule_set(self): # Can be compiled with empty dict _validate_optimization_configuration( + is_jumpstart=False, 
instance_type="ml.inf2.xlarge", quantization_config=None, sharding_config=None, diff --git a/tests/unit/sagemaker/serve/utils/test_optimize_utils.py b/tests/unit/sagemaker/serve/utils/test_optimize_utils.py index 2dbd415eee..b392b255da 100644 --- a/tests/unit/sagemaker/serve/utils/test_optimize_utils.py +++ b/tests/unit/sagemaker/serve/utils/test_optimize_utils.py @@ -284,7 +284,10 @@ def test_is_draft_model_gated(draft_model_config, expected): @pytest.mark.parametrize( - "quantization_config, compilation_config, sharding_config, expected_config, expected_quant_env, expected_compilation_env, expected_sharding_env", + ( + "quantization_config, compilation_config, sharding_config, expected_config, " + "expected_quant_env, expected_compilation_env, expected_sharding_env" + ), [ ( None,