Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions sagemaker-mlops/src/sagemaker/mlops/workflow/emr_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""The step definitions for workflow."""

from __future__ import absolute_import

from typing import Any, Dict, List, Union, Optional
Expand All @@ -27,7 +28,12 @@ class EMRStepConfig:
"""Config for a Hadoop Jar step."""

def __init__(
    self,
    jar,
    args: Optional[List[str]] = None,
    main_class: Optional[str] = None,
    properties: Optional[List[dict]] = None,
    output_args: Optional[Dict[str, str]] = None,
):
    """Create a definition for input data used by an EMR cluster(job flow) step.

    Args:
        jar(str): A path to a JAR file run during the step.
        args(List[str]): Command line args for the step. The list passed in is
            never modified; when ``output_args`` is given, a copy is extended.
        main_class(str): The name of the main class in the specified Java file.
        properties(List(dict)): A list of key-value pairs that are set when the step runs.
        output_args(Dict[str, str]):
            A dict of argument-value pairs (output_name: S3 URI) that extends the command line
            args and can be accessible in other steps via EMRStep.emr_outputs[output_name].
            Argument names are prepended by '--' automatically.
            Example: {"output-path": "s3://my-bucket/output/"} will result in the following
            command line args: ["--output-path", "s3://my-bucket/output/"]
    """
    self.jar = jar
    self.args = args
    self.main_class = main_class
    self.properties = properties

    # Maps each output name to the position of its value inside ``self.args``.
    self.output_args_index = {}
    if output_args:
        # Build a fresh list: ``args`` may be None (the default), and the
        # caller's list must not be mutated as a side effect of construction.
        self.args = list(args) if args is not None else []
        for output_arg_name, output_arg_value in output_args.items():
            self.args.extend([f"--{output_arg_name}", output_arg_value])
            # The value was appended last, so it sits at the final index.
            self.output_args_index[output_arg_name] = len(self.args) - 1

def to_request(self) -> RequestType:
"""Convert EMRStepConfig object to request dict."""
config = {"HadoopJarStep": {"Jar": self.jar}}
Expand Down Expand Up @@ -230,6 +248,11 @@ def __init__(
self.cache_config = cache_config
self._properties = root_property

self.emr_outputs = {
output_name: self.properties.Config.Args[step_config.output_args_index[output_name]]
for output_name in step_config.output_args_index
}

@property
def arguments(self) -> RequestType:
"""The arguments dict that is used to call `AddJobFlowSteps`.
Expand All @@ -250,4 +273,4 @@ def to_request(self) -> RequestType:
request_dict = super().to_request()
if self.cache_config:
request_dict.update(self.cache_config.config)
return request_dict
return request_dict
25 changes: 18 additions & 7 deletions sagemaker-mlops/tests/unit/workflow/test_emr_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from sagemaker.mlops.workflow.emr_step import EMRStep, EMRStepConfig
from sagemaker.mlops.workflow.steps import StepTypeEnum
from sagemaker.core.workflow.properties import Properties


def test_emr_step_config_init():
Expand All @@ -39,7 +40,7 @@ def test_emr_step_with_cluster_id():
display_name="EMR Step",
description="Test EMR step",
cluster_id="j-123456",
step_config=config
step_config=config,
)
assert step.name == "emr-step"
assert step.step_type == StepTypeEnum.EMR
Expand All @@ -48,17 +49,15 @@ def test_emr_step_with_cluster_id():
def test_emr_step_with_cluster_config():
    """An EMRStep can be built from a cluster config when no cluster id is given."""
    config = EMRStepConfig(jar="s3://bucket/my.jar")
    cluster_config = {
        "Instances": {"InstanceGroups": [{"InstanceType": "m5.xlarge", "InstanceCount": 1}]}
    }
    step = EMRStep(
        name="emr-step",
        display_name="EMR Step",
        description="Test EMR step",
        cluster_id=None,
        step_config=config,
        cluster_config=cluster_config,
    )
    assert step.name == "emr-step"

Expand All @@ -71,7 +70,7 @@ def test_emr_step_without_cluster_id_or_config_raises_error():
display_name="EMR Step",
description="Test EMR step",
cluster_id=None,
step_config=config
step_config=config,
)


Expand All @@ -84,5 +83,17 @@ def test_emr_step_with_both_cluster_id_and_config_raises_error():
description="Test EMR step",
cluster_id="j-123456",
step_config=config,
cluster_config={"Instances": {}}
cluster_config={"Instances": {}},
)

def test_emr_step_with_output_args():
    """Each declared output arg is exposed on the step as a Properties reference."""
    step_config = EMRStepConfig(
        jar="s3://bucket/my.jar",
        args=["arg1"],
        output_args={"output": "s3://bucket/my/output/path"},
    )
    emr_step = EMRStep(
        name="emr-step",
        display_name="EMR Step",
        description="Test EMR step",
        cluster_id="j-123456",
        step_config=step_config,
    )
    outputs = emr_step.emr_outputs
    assert "output" in outputs
    assert isinstance(outputs["output"], Properties)