Skip to content

Commit f20ee01

Browse files
committed
Update s3 bucket check in session_helper.py
Code change is based on commit: 903cb8a
1 parent c08847d commit f20ee01

File tree

2 files changed

+109
-4
lines changed

2 files changed

+109
-4
lines changed

sagemaker-core/src/sagemaker/core/helper/session_helper.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -666,9 +666,16 @@ def expected_bucket_owner_id_bucket_check(self, bucket_name, s3, expected_bucket
666666
667667
"""
668668
try:
669-
s3.meta.client.head_bucket(
670-
Bucket=bucket_name, ExpectedBucketOwner=expected_bucket_owner_id
671-
)
669+
if self.default_bucket_prefix:
670+
s3.meta.client.list_objects_v2(
671+
Bucket=bucket_name,
672+
Prefix=self.default_bucket_prefix,
673+
ExpectedBucketOwner=expected_bucket_owner_id,
674+
)
675+
else:
676+
s3.meta.client.head_bucket(
677+
Bucket=bucket_name, ExpectedBucketOwner=expected_bucket_owner_id
678+
)
672679
except ClientError as e:
673680
error_code = e.response["Error"]["Code"]
674681
message = e.response["Error"]["Message"]
@@ -699,7 +706,12 @@ def general_bucket_check_if_user_has_permission(
699706
bucket_creation_date_none (bool):Indicating whether S3 bucket already exists or not
700707
"""
701708
try:
702-
s3.meta.client.head_bucket(Bucket=bucket_name)
709+
if self.default_bucket_prefix:
710+
s3.meta.client.list_objects_v2(
711+
Bucket=bucket_name, Prefix=self.default_bucket_prefix
712+
)
713+
else:
714+
s3.meta.client.head_bucket(Bucket=bucket_name)
703715
except ClientError as e:
704716
error_code = e.response["Error"]["Code"]
705717
message = e.response["Error"]["Message"]

sagemaker-core/tests/unit/helper/test_session_helper.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1319,3 +1319,96 @@ def test_endpoint_not_found(self, mock_boto_session, mock_sagemaker_client):
13191319
)
13201320
session = Session(boto_session=mock_boto_session, sagemaker_client=mock_sagemaker_client)
13211321
assert session.endpoint_in_service_or_not("my-endpoint") is False
1322+
1323+
1324+
class TestBucketCheckWithPrefix:
1325+
"""Test bucket check methods with default_bucket_prefix."""
1326+
1327+
@pytest.fixture
1328+
def session_with_prefix(self, mock_boto_session, mock_sagemaker_client):
1329+
"""Create session with bucket prefix."""
1330+
mock_sts_client = Mock()
1331+
mock_sts_client.get_caller_identity.return_value = {"Account": "123456789012"}
1332+
mock_boto_session.client.return_value = mock_sts_client
1333+
1334+
session = Session(
1335+
boto_session=mock_boto_session,
1336+
sagemaker_client=mock_sagemaker_client,
1337+
default_bucket="test-bucket",
1338+
default_bucket_prefix="sample-prefix",
1339+
)
1340+
mock_s3_resource = Mock()
1341+
mock_bucket = Mock()
1342+
mock_bucket.creation_date = None
1343+
mock_s3_resource.Bucket.return_value = mock_bucket
1344+
session.s3_resource = mock_s3_resource
1345+
return session
1346+
1347+
def test_default_bucket_with_prefix_forbidden(self, session_with_prefix, caplog):
1348+
"""Test forbidden error when accessing bucket with prefix."""
1349+
error = ClientError(
1350+
error_response={"Error": {"Code": "403", "Message": "Forbidden"}},
1351+
operation_name="ListObjectsV2",
1352+
)
1353+
session_with_prefix.s3_resource.meta.client.list_objects_v2.side_effect = error
1354+
1355+
with pytest.raises(ClientError):
1356+
session_with_prefix.default_bucket()
1357+
1358+
assert "Please try again after adding appropriate access." in caplog.text
1359+
assert session_with_prefix._default_bucket is None
1360+
session_with_prefix.s3_resource.meta.client.list_objects_v2.assert_called_once_with(
1361+
Bucket="test-bucket", Prefix="sample-prefix"
1362+
)
1363+
1364+
def test_expected_bucket_owner_check_with_prefix(self, session_with_prefix):
1365+
"""Test expected bucket owner check uses list_objects_v2 with prefix."""
1366+
session_with_prefix.expected_bucket_owner_id_bucket_check(
1367+
"test-bucket", session_with_prefix.s3_resource, "123456789012"
1368+
)
1369+
session_with_prefix.s3_resource.meta.client.list_objects_v2.assert_called_once_with(
1370+
Bucket="test-bucket", Prefix="sample-prefix", ExpectedBucketOwner="123456789012"
1371+
)
1372+
1373+
def test_expected_bucket_owner_check_without_prefix(self, mock_boto_session, mock_sagemaker_client):
1374+
"""Test expected bucket owner check uses head_bucket without prefix."""
1375+
session = Session(
1376+
boto_session=mock_boto_session,
1377+
sagemaker_client=mock_sagemaker_client,
1378+
default_bucket="test-bucket",
1379+
)
1380+
mock_s3_resource = Mock()
1381+
session.s3_resource = mock_s3_resource
1382+
1383+
session.expected_bucket_owner_id_bucket_check(
1384+
"test-bucket", mock_s3_resource, "123456789012"
1385+
)
1386+
mock_s3_resource.meta.client.head_bucket.assert_called_once_with(
1387+
Bucket="test-bucket", ExpectedBucketOwner="123456789012"
1388+
)
1389+
1390+
def test_general_bucket_check_with_prefix(self, session_with_prefix):
1391+
"""Test general bucket check uses list_objects_v2 with prefix."""
1392+
mock_bucket = Mock()
1393+
session_with_prefix.general_bucket_check_if_user_has_permission(
1394+
"test-bucket", session_with_prefix.s3_resource, mock_bucket, "us-west-2", True
1395+
)
1396+
session_with_prefix.s3_resource.meta.client.list_objects_v2.assert_called_once_with(
1397+
Bucket="test-bucket", Prefix="sample-prefix"
1398+
)
1399+
1400+
def test_general_bucket_check_without_prefix(self, mock_boto_session, mock_sagemaker_client):
1401+
"""Test general bucket check uses head_bucket without prefix."""
1402+
session = Session(
1403+
boto_session=mock_boto_session,
1404+
sagemaker_client=mock_sagemaker_client,
1405+
default_bucket="test-bucket",
1406+
)
1407+
mock_s3_resource = Mock()
1408+
mock_bucket = Mock()
1409+
session.s3_resource = mock_s3_resource
1410+
1411+
session.general_bucket_check_if_user_has_permission(
1412+
"test-bucket", mock_s3_resource, mock_bucket, "us-west-2", True
1413+
)
1414+
mock_s3_resource.meta.client.head_bucket.assert_called_once_with(Bucket="test-bucket")

0 commit comments

Comments
 (0)