12 changes: 12 additions & 0 deletions transformer_engine/common/triton/permutation.py
@@ -563,6 +563,13 @@ def _make_chunk_sort_map_kernel(
split_sizes_ptr + load_split_offset, mask=load_split_offset < num_splits, other=0
).to(tl.int32)
input_split_sizes_cumsum = tl.cumsum(input_split_sizes)

# Compute total valid tokens and skip phantom/padding tokens.
# When the input buffer is larger than sum(split_sizes), tokens beyond
# the valid range should map to themselves (identity mapping) to avoid
# corrupting valid output positions.
total_valid_tokens = tl.sum(input_split_sizes)

input_split_sizes_mask = tl.where(input_split_sizes_cumsum <= pid, 1, 0)
input_chunk_idx = tl.sum(input_split_sizes_mask)
input_split_sizes_presum = tl.sum(input_split_sizes * input_split_sizes_mask)
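To make the indexing concrete (illustrative values, not from the PR): with split_sizes = [2, 3] and pid = 3, the cumsum is [2, 5] and the mask (cumsum <= pid) is [1, 0], so input_chunk_idx = 1 and input_split_sizes_presum = 2, i.e. pid lands 1 token into chunk 1.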
@@ -578,6 +585,11 @@
).to(tl.int32)
output_pre_split_sizes = tl.where(load_split_offset < output_chunk_idx, output_split_sizes, 0)
dst_row = tl.sum(output_pre_split_sizes) + in_chunk_offset

# For tokens beyond the valid range (pid >= total_valid_tokens),
# use identity mapping to avoid corrupting valid data
dst_row = tl.where(pid < total_valid_tokens, dst_row, pid)
Collaborator:
Do we need to copy the tokens/overallocated space beyond the valid range to the output? Or can we skip those and leave the output uninitialized since they are beyond the valid range and won't be used downstream?

Maybe this could be achieved by using a mask arg to load/store that is dependent on tl.sum(input_split_sizes)?
https://triton-lang.org/main/python-api/generated/triton.language.store.html#triton.language.store
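A minimal sketch of the masked store suggested above (illustrative only, not part of the PR; total_valid_tokens is the value computed in the hunk above):

    # Hypothetical: store only rows inside the valid range and leave the
    # overallocated tail of dst_rows uninitialized.
    tl.store(dst_rows_ptr + pid, dst_row, mask=pid < total_valid_tokens)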


tl.store(dst_rows_ptr + pid, dst_row)


3 changes: 3 additions & 0 deletions transformer_engine/jax/csrc/extensions.h
@@ -143,6 +143,9 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(RHTAmaxCalculationInitializeHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(RHTAmaxCalculationHandler);

// Inspect
XLA_FFI_DECLARE_HANDLER_SYMBOL(InspectHandler);

// Cudnn helpers
XLA_FFI_DECLARE_HANDLER_SYMBOL(CudnnHandleInitHandler);

43 changes: 43 additions & 0 deletions transformer_engine/jax/csrc/extensions/amax.cpp
@@ -5,6 +5,7 @@
************************************************************************/
#include <cuda_runtime.h>

#include <fstream>
#include <iostream>

#include "../extensions.h"
@@ -96,5 +97,47 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(
.Attr<bool>("produce_regular_amax") // produce_regular_amax
.Attr<int64_t>("flatten_axis")); // flatten_axis

Error_Type InspectFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type output_buf) {
NVTE_CHECK(input_buf.untyped_data() != nullptr, "Input must be provided for inspect operation");
NVTE_CHECK(output_buf->untyped_data() != nullptr,
"Output must be provided for inspect operation");
NVTE_CHECK(input_buf.untyped_data() == output_buf->untyped_data(),
"Input and output must point to the same buffer for inspect operation");

std::vector<uint8_t> input_data(input_buf.size_bytes());
cudaMemcpyAsync(input_data.data(), input_buf.untyped_data(), input_buf.size_bytes(),
cudaMemcpyDeviceToHost, stream);
cudaStreamSynchronize(stream);

int device;
cudaGetDevice(&device);

std::string filename = "my_tensor_gpu" + std::to_string(device) + ".bin";
std::ofstream file(filename, std::ios::binary);
if (file.is_open()) {
file.write(reinterpret_cast<const char *>(input_data.data()), input_data.size());
file.close();
}
Contributor, commenting on lines +116 to +120:
No error handling if the file fails to open; the handler silently continues without writing any data.

Suggested change, from:

    std::ofstream file(filename, std::ios::binary);
    if (file.is_open()) {
      file.write(reinterpret_cast<const char *>(input_data.data()), input_data.size());
      file.close();
    }

to:

    std::ofstream file(filename, std::ios::binary);
    if (!file.is_open()) {
      return ffi::Error(ffi::ErrorCode::kInternal, "Failed to open file for writing");
    }
    file.write(reinterpret_cast<const char *>(input_data.data()), input_data.size());
    file.close();

printf("Tensor data written to %s (shape: [", filename.c_str());
for (size_t i = 0; i < input_buf.dimensions().size(); ++i) {
printf("%ld", static_cast<long>(input_buf.dimensions()[i]));
if (i < input_buf.dimensions().size() - 1) {
printf(", ");
}
}
printf("], dtype: %d)\n", static_cast<int>(input_buf.element_type()));

// TODO: make a metadata file with tensor shape and dtype?

return ffi_with_cuda_error_check();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(InspectHandler, InspectFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // input
.Ret<Buffer_Type>() // output
);

} // namespace jax
} // namespace transformer_engine
3 changes: 3 additions & 0 deletions transformer_engine/jax/csrc/extensions/pybind.cpp
@@ -81,6 +81,9 @@ pybind11::dict Registrations() {
pybind11::arg("initialize") = EncapsulateFFI(RHTAmaxCalculationInitializeHandler),
pybind11::arg("execute") = EncapsulateFFI(RHTAmaxCalculationHandler));

dict["te_inspect_ffi"] =
pybind11::dict(pybind11::arg("execute") = EncapsulateFFI(InspectHandler));

return dict;
}

128 changes: 128 additions & 0 deletions transformer_engine/jax/inspect.py
@@ -0,0 +1,128 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX array inspection utilities."""

from functools import partial

import jax
import jax.numpy as jnp
from jax import ffi

from .cpp_extensions.base import BasePrimitive, register_primitive

__all__ = ["inspect_array", "load_array_dump"]


class InspectPrimitive(BasePrimitive):
"""
No-op used for inspect array values.
"""

name = "te_inspect_ffi"
multiple_results = False
impl_static_args = ()
inner_primitive = None
outer_primitive = None

@staticmethod
def abstract(
x_aval,
):
"""
inspect abstract
"""
return x_aval

@staticmethod
def lowering(
ctx,
x,
):
"""
inspect lowering rules
"""

return ffi.ffi_lowering(
InspectPrimitive.name,
operand_output_aliases={0: 0}, # donate input buffer to output buffer
)(
ctx,
x,
)

@staticmethod
def impl(
x,
):
"""
inspect implementation
"""
assert InspectPrimitive.inner_primitive is not None
(x) = InspectPrimitive.inner_primitive.bind(
x,
)
return x


register_primitive(InspectPrimitive)


@partial(jax.custom_vjp, nondiff_argnums=())
def _inspect(
x,
):
""" """
output, _ = _inspect_fwd_rule(
x,
)
return output


def _inspect_fwd_rule(
x,
):
""""""
ctx = ()
x = InspectPrimitive.outer_primitive.bind(x)
return x, ctx


def _inspect_bwd_rule(
ctx,
grad,
):
""""""
del ctx
return (grad,)


_inspect.defvjp(_inspect_fwd_rule, _inspect_bwd_rule)


def inspect_array(x: jnp.ndarray, name: str) -> jnp.ndarray:
Contributor:
The name parameter is unused; it is neither passed to the C++ backend nor used in the filename.

Suggested change, from:

    def inspect_array(x: jnp.ndarray, name: str) -> jnp.ndarray:

to:

    def inspect_array(x: jnp.ndarray) -> jnp.ndarray:

"""Utility function to inspect JAX arrays by printing their name, shape, dtype, and statistics.

Args:
x (jnp.ndarray): The JAX array to inspect.
name (str): The name of the array for identification in the output.
"""
# TODO: Handle the name of the tensor in the primitive and output files
return _inspect(x)


def load_array_dump(filename: str, shape: tuple, dtype: jnp.dtype) -> jnp.ndarray:
"""Utility function to load a JAX array from a dumped binary file.

Args:
filename (str): The path to the binary file containing the array data.
shape (tuple): The shape of the array to be loaded.
dtype (jnp.dtype): The data type of the array to be loaded.

Returns:
jnp.ndarray: The loaded JAX array.
"""
with open(filename, "rb") as f:
data = f.read()
array = jnp.frombuffer(data, dtype=dtype).reshape(shape)
return array
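A usage sketch for these two utilities (assumptions: a single process running on GPU 0, so the FFI handler writes my_tensor_gpu0.bin; the name argument is currently unused, as flagged above):

    import jax.numpy as jnp
    from transformer_engine.jax.inspect import inspect_array, load_array_dump

    x = jnp.arange(12, dtype=jnp.float32).reshape(3, 4)
    y = inspect_array(x, name="x")  # no-op on the value; dumps the buffer as a side effect
    restored = load_array_dump("my_tensor_gpu0.bin", shape=(3, 4), dtype=jnp.float32)
    assert bool((restored == y).all())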
23 changes: 14 additions & 9 deletions transformer_engine/jax/permutation.py
@@ -581,7 +581,7 @@ def sort_chunks_by_index(
return _sort_chunks_by_index(inp, split_sizes, sorted_indices)


@partial(jax.custom_vjp, nondiff_argnums=(1, 2))
@jax.custom_vjp
def _sort_chunks_by_index(
inp: jnp.ndarray,
split_sizes: jnp.ndarray,
@@ -596,7 +596,7 @@ def _sort_chunks_by_index_fwd_rule(
inp: jnp.ndarray,
split_sizes: jnp.ndarray,
sorted_indices: jnp.ndarray,
) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], Tuple[jnp.ndarray, int, int]]:
) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, int, int]]:
"""Forward pass rule for sort_chunks_by_index."""
# Validate input dimensions
assert inp.ndim in [2, 3], f"inp must be 2D or 3D, got {inp.ndim}D"
@@ -618,18 +618,17 @@ def _sort_chunks_by_index_fwd_rule(
)

# Return (primals, residuals)
residuals = (row_id_map, num_tokens, hidden_size)
# Include split_sizes and sorted_indices in residuals since we removed nondiff_argnums
residuals = (row_id_map, split_sizes, sorted_indices, num_tokens, hidden_size)
return (output, row_id_map), residuals


def _sort_chunks_by_index_bwd_rule(
_split_sizes: jnp.ndarray,
_sorted_indices: jnp.ndarray,
residuals: Tuple[jnp.ndarray, int, int],
residuals: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, int, int],
g: Tuple[jnp.ndarray, jnp.ndarray],
) -> Tuple[jnp.ndarray]:
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""Backward pass rule for sort_chunks_by_index."""
row_id_map, num_tokens, hidden_size = residuals
row_id_map, split_sizes, sorted_indices, num_tokens, hidden_size = residuals
output_grad, _ = g

# Backward: reverse the sort
Expand All @@ -642,7 +641,13 @@ def _sort_chunks_by_index_bwd_rule(
is_forward=False,
)

return (inp_grad,)
# Return gradients for all inputs: (inp, split_sizes, sorted_indices)
Collaborator:
I'm not following why split_sizes and sorted_indices can't be non-differentiable args; can you let me know the reasoning? Thanks!
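For context (my reading, not the author's stated reasoning): jax.custom_vjp treats nondiff_argnums entries as static Python values, and the JAX docs advise against using it for array-valued arguments, since traced jnp arrays get closed over as constants and can leak tracers under jit. The documented pattern is what the diff below does: pass such arrays as regular primals and return zero cotangents for them. A minimal sketch of that pattern, using a hypothetical gather_rows function:

    import jax
    import jax.numpy as jnp

    @jax.custom_vjp
    def gather_rows(x, idx):  # idx is a traced integer array, so it cannot be a nondiff argnum
        return x[idx]

    def gather_rows_fwd(x, idx):
        return x[idx], (x.shape, idx)

    def gather_rows_bwd(res, g):
        shape, idx = res
        dx = jnp.zeros(shape, g.dtype).at[idx].add(g)  # scatter the cotangent back
        return dx, jnp.zeros_like(idx)  # one cotangent per primal; zeros for the index array

    gather_rows.defvjp(gather_rows_fwd, gather_rows_bwd)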

# split_sizes and sorted_indices are integer arrays, so their gradients are zeros
# with matching dtype (use float32 as a safe default for index arrays)
split_sizes_grad = jnp.zeros_like(split_sizes, dtype=split_sizes.dtype)
Contributor:
The comment says "use float32 as a safe default", but the code preserves the original dtype (likely int32/int64); the comment is misleading.

sorted_indices_grad = jnp.zeros_like(sorted_indices, dtype=sorted_indices.dtype)

return (inp_grad, split_sizes_grad, sorted_indices_grad)


_sort_chunks_by_index.defvjp(_sort_chunks_by_index_fwd_rule, _sort_chunks_by_index_bwd_rule)