Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 77 additions & 37 deletions sagemaker-serve/src/sagemaker/serve/bedrock_model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,15 @@
from __future__ import absolute_import

from typing import Optional, Dict, Any, Union
import os
try:
    from botocore.exceptions import ClientError
except ImportError:
    # botocore is not installed (minimal test/dev environments). Only a missing
    # package should trigger the fallback; any other failure raised while
    # importing botocore is a real problem and must surface.
    class ClientError(Exception):
        """Minimal stand-in for ``botocore.exceptions.ClientError``.

        Args:
            error_response: Error payload; stored unchanged on ``response``.
            operation_name: Accepted for signature compatibility; not used.
        """

        def __init__(self, error_response, operation_name=None):
            self.response = error_response
            super().__init__(str(error_response))

from sagemaker.core.helper.session_helper import Session
from sagemaker.core.resources import TrainingJob, ModelPackage
Expand Down Expand Up @@ -198,49 +207,80 @@ def _get_checkpoint_uri_from_manifest(self) -> Optional[str]:
import json
from urllib.parse import urlparse
import logging

logger = logging.getLogger(__name__)

if not isinstance(self.model, TrainingJob):
raise ValueError("Model must be a TrainingJob instance for Nova models")

# Step 1: Get S3 model artifacts from training job
s3_artifacts = self.model.model_artifacts.s3_model_artifacts
if not s3_artifacts:
raise ValueError("No S3 model artifacts found in training job")

logger.info(f"S3 artifacts path: {s3_artifacts}")
# Step 2: Construct manifest path (same directory as model artifacts)
# s3://bucket/path/output/model.tar.gz -> s3://bucket/path/output/output/manifest.json
parts = s3_artifacts.rstrip('/').rsplit('/', 1)
manifest_path = parts[0] + '/output/manifest.json'

logger.info(f"Manifest path: {manifest_path}")

# Step 3: Find and read manifest.json
parsed = urlparse(manifest_path)
bucket = parsed.netloc
manifest_key = parsed.path.lstrip('/')

logger.info(f"Looking for manifest at s3://{bucket}/{manifest_key}")

# Parse S3 URI and build candidate manifest keys. The manifest may be located
# beside the model artifact or in an `output/` subdirectory depending on
# how training jobs write artifacts. Try multiple reasonable candidates.
parsed_artifacts = urlparse(s3_artifacts)
bucket = parsed_artifacts.netloc
key = parsed_artifacts.path.lstrip('/')
base_dir = os.path.dirname(key)

candidate_keys = [
f"{base_dir}/manifest.json",
f"{base_dir}/output/manifest.json",
f"{os.path.dirname(base_dir)}/manifest.json",
]

s3_client = self.boto_session.client('s3')
try:
response = s3_client.get_object(Bucket=bucket, Key=manifest_key)
manifest = json.loads(response['Body'].read().decode('utf-8'))
logger.info(f"Manifest content: {manifest}")

# Step 4: Fetch checkpoint_s3_bucket from manifest
checkpoint_uri = manifest.get('checkpoint_s3_bucket')
if not checkpoint_uri:
raise ValueError(f"'checkpoint_s3_bucket' not found in manifest. Available keys: {list(manifest.keys())}")

logger.info(f"Checkpoint URI: {checkpoint_uri}")
return checkpoint_uri
except s3_client.exceptions.NoSuchKey:
raise ValueError(f"manifest.json not found at s3://{bucket}/{manifest_key}")
except json.JSONDecodeError as e:
raise ValueError(f"Failed to parse manifest.json: {e}")
except Exception as e:
raise ValueError(f"Error reading manifest.json: {e}")
last_manifest = None
for manifest_key in candidate_keys:
logger.debug(f"Looking for manifest at s3://{bucket}/{manifest_key}")
try:
response = s3_client.get_object(Bucket=bucket, Key=manifest_key)
except ClientError as e:
code = e.response.get('Error', {}).get('Code', '')
# Continue searching if the key/bucket was not found; re-raise unexpected errors
if code in ("NoSuchKey", "NoSuchBucket", "404"):
logger.debug(f"manifest not found at {manifest_key}: {code}")
continue
raise

try:
body = response['Body'].read().decode('utf-8')
manifest = json.loads(body)
last_manifest = manifest
logger.debug(f"Manifest content keys: {list(manifest.keys())}")

# Try to find any manifest entry that looks like a checkpoint URI
checkpoint_uri = None
for k, v in manifest.items():
if 'checkpoint' in k.lower():
checkpoint_uri = v
break

# Fallback to common keys
if not checkpoint_uri:
checkpoint_uri = (
manifest.get('checkpoint_s3_bucket')
or manifest.get('checkpoint_s3_uri')
or manifest.get('checkpoint_s3_path')
)

if not checkpoint_uri:
# If manifest present but no checkpoint-like key, keep searching other candidates
logger.debug(f"No checkpoint key in manifest at {manifest_key}")
continue

logger.info(f"Checkpoint URI: {checkpoint_uri}")
return checkpoint_uri

except json.JSONDecodeError as e:
raise ValueError(f"Failed to parse manifest.json at s3://{bucket}/{manifest_key}: {e}")

# If we exhausted candidates without finding a checkpoint, raise a clear error
tried = ", ".join([f"s3://{bucket}/{k}" for k in candidate_keys])
extra = f" Available manifest keys: {list(last_manifest.keys())}" if last_manifest else ""
raise ValueError(f"manifest.json with checkpoint not found. Tried: {tried}.{extra}")
138 changes: 138 additions & 0 deletions tests/unit/test_bedrock_model_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import io
import json
import sys
import os
from types import SimpleNamespace

import pytest

# Make the in-repo `src` trees importable. The repository root is two directory
# levels above this test file.
ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir))
# Prepend both source directories in one step; sagemaker-core ends up ahead of
# sagemaker-serve on sys.path.
sys.path[:0] = [
    os.path.join(ROOT, "sagemaker-core", "src"),
    os.path.join(ROOT, "sagemaker-serve", "src"),
]
import types
import importlib.util

# Provide minimal sagemaker core/train/telemetry stubs to avoid heavy imports.
# NOTE(review): the stubs remain in sys.modules after this file is imported, so
# anything imported later in the same process receives them instead of the real
# packages -- confirm that is acceptable for the rest of the test suite.
if 'sagemaker.core' not in sys.modules:
    # core package
    sagemaker_mod = types.ModuleType('sagemaker')
    core_mod = types.ModuleType('sagemaker.core')
    helper_mod = types.ModuleType('sagemaker.core.helper')
    session_helper_mod = types.ModuleType('sagemaker.core.helper.session_helper')

    class Session:
        """Session stub exposing only a ``boto_session`` attribute."""

        def __init__(self, boto_session=None):
            # The argument is accepted but ignored; client() always returns None.
            self.boto_session = types.SimpleNamespace(client=lambda svc: None)

    session_helper_mod.Session = Session

    resources_mod = types.ModuleType('sagemaker.core.resources')

    class TrainingJob:
        """Empty placeholder for the TrainingJob resource class."""

    class ModelPackage:
        """Placeholder for the ModelPackage resource class."""

        @classmethod
        def get(cls, arn):
            return None

    resources_mod.TrainingJob = TrainingJob
    resources_mod.ModelPackage = ModelPackage

    telemetry_mod = types.ModuleType('sagemaker.core.telemetry')
    telemetry_logging_mod = types.ModuleType('sagemaker.core.telemetry.telemetry_logging')

    def _telemetry_emitter(feature=None, func_name=None):
        """Return a decorator that hands back the decorated function unchanged."""
        def deco(f):
            return f
        return deco

    telemetry_logging_mod._telemetry_emitter = _telemetry_emitter

    telemetry_constants_mod = types.ModuleType('sagemaker.core.telemetry.constants')

    class Feature:
        """Stub exposing the single constant defined here."""

        MODEL_CUSTOMIZATION = 'MODEL_CUSTOMIZATION'

    telemetry_constants_mod.Feature = Feature

    # train module stub
    train_mod = types.ModuleType('sagemaker.train')
    model_trainer_mod = types.ModuleType('sagemaker.train.model_trainer')

    class ModelTrainer:
        """Empty placeholder for the ModelTrainer class."""

    model_trainer_mod.ModelTrainer = ModelTrainer

    # Register every stub under its own dotted name. The top-level entry is the
    # `sagemaker_mod` object created above rather than a second, unrelated module.
    sys.modules.update({
        stub.__name__: stub
        for stub in (
            sagemaker_mod,
            core_mod,
            helper_mod,
            session_helper_mod,
            resources_mod,
            telemetry_mod,
            telemetry_logging_mod,
            telemetry_constants_mod,
            train_mod,
            model_trainer_mod,
        )
    })

# Provide minimal botocore.exceptions.ClientError for environments without botocore.
# A missing sys.modules entry only means botocore has not been imported *yet*, so
# also ask the import system whether the package can be found; otherwise an
# installed botocore would be shadowed by this stub for the rest of the process.
if 'botocore' not in sys.modules and importlib.util.find_spec('botocore') is None:
    botocore_mod = types.ModuleType('botocore')

    class _ClientError(Exception):
        """Stand-in error that keeps the error payload on ``response``."""

        def __init__(self, error_response, operation_name=None):
            self.response = error_response
            super().__init__(str(error_response))

    # Register a module object (not a bare namespace) under the dotted name.
    exceptions_mod = types.ModuleType('botocore.exceptions')
    exceptions_mod.ClientError = _ClientError
    botocore_mod.exceptions = exceptions_mod
    sys.modules['botocore'] = botocore_mod
    sys.modules['botocore.exceptions'] = exceptions_mod

# Load bedrock_model_builder module directly from source to avoid importing the
# package-level __init__ which pulls in many heavy dependencies.
# The module is executed under the private name 'bb_module'; the steps below do
# not add it to sys.modules, so the tests reach it only through `bb_module`.
module_path = os.path.join(ROOT, 'sagemaker-serve', 'src', 'sagemaker', 'serve', 'bedrock_model_builder.py')
spec = importlib.util.spec_from_file_location('bb_module', module_path)
bb_module = importlib.util.module_from_spec(spec)
# Runs the module body; its own imports are resolved at this point.
spec.loader.exec_module(bb_module)
# Class under test, re-exported for brevity in the test functions.
BedrockModelBuilder = bb_module.BedrockModelBuilder


class FakeTrainingJob:
    """Training-job double carrying an artifact location and no model package ARN."""

    def __init__(self, s3_uri: str):
        artifacts = SimpleNamespace()
        artifacts.s3_model_artifacts = s3_uri
        self.output_model_package_arn = None
        self.model_artifacts = artifacts


def make_s3_client_success(checkpoint_value: str):
    """Build an S3 client double whose ``get_object`` always succeeds.

    Every call, whatever the bucket or key, returns a JSON manifest body of the
    form ``{"checkpoint_s3_bucket": checkpoint_value}`` wrapped in a byte stream.
    """

    class S3Stub:
        def get_object(self, Bucket, Key):
            manifest = {"checkpoint_s3_bucket": checkpoint_value}
            encoded = json.dumps(manifest).encode('utf-8')
            # A new stream per call, so each response is readable from the start.
            return {"Body": io.BytesIO(encoded)}

    return S3Stub()


def make_s3_client_not_found():
    """Build an S3 client double whose ``get_object`` always fails with NoSuchKey.

    The error is raised as ``bb_module.ClientError`` so it is the same class the
    module under test catches.
    """

    class S3Stub:
        def get_object(self, Bucket, Key):
            missing_key = {"Error": {"Code": "NoSuchKey", "Message": "Not found"}}
            raise bb_module.ClientError(missing_key, 'GetObject')

    stub = S3Stub()
    return stub


def test_get_checkpoint_uri_success(monkeypatch):
    """A manifest carrying ``checkpoint_s3_bucket`` yields that value as the URI."""
    s3_uri = 's3://mybucket/path/output/model.tar.gz'
    fake_job = FakeTrainingJob(s3_uri)

    # Swap the TrainingJob symbol so the isinstance check accepts the fake. Going
    # through the monkeypatch fixture restores the original when the test ends,
    # instead of leaving the module patched for every later test.
    monkeypatch.setattr(bb_module, 'TrainingJob', FakeTrainingJob)

    builder = BedrockModelBuilder(model=fake_job)
    builder.boto_session = SimpleNamespace(
        client=lambda service: make_s3_client_success('s3://checkpoint-bucket/checkpoint')
    )

    uri = builder._get_checkpoint_uri_from_manifest()
    assert uri == 's3://checkpoint-bucket/checkpoint'


def test_get_checkpoint_uri_manifest_not_found(monkeypatch):
    """When every candidate key is missing, the lookup fails with a ValueError."""
    s3_uri = 's3://mybucket/path/output/model.tar.gz'
    fake_job = FakeTrainingJob(s3_uri)

    # Patch through the fixture so the replacement is undone after the test.
    monkeypatch.setattr(bb_module, 'TrainingJob', FakeTrainingJob)

    builder = BedrockModelBuilder(model=fake_job)
    builder.boto_session = SimpleNamespace(client=lambda service: make_s3_client_not_found())

    # Match the message so that an unrelated ValueError (for example the
    # "Model must be a TrainingJob instance" check) cannot satisfy this test.
    with pytest.raises(ValueError, match=r"manifest\.json with checkpoint not found"):
        builder._get_checkpoint_uri_from_manifest()