diff --git a/transformer_engine/common/triton/permutation.py b/transformer_engine/common/triton/permutation.py index 4602f41cfd..147742bb05 100644 --- a/transformer_engine/common/triton/permutation.py +++ b/transformer_engine/common/triton/permutation.py @@ -563,6 +563,13 @@ def _make_chunk_sort_map_kernel( split_sizes_ptr + load_split_offset, mask=load_split_offset < num_splits, other=0 ).to(tl.int32) input_split_sizes_cumsum = tl.cumsum(input_split_sizes) + + # Compute total valid tokens and skip phantom/padding tokens. + # When the input buffer is larger than sum(split_sizes), tokens beyond + # the valid range should map to themselves (identity mapping) to avoid + # corrupting valid output positions. + total_valid_tokens = tl.sum(input_split_sizes) + input_split_sizes_mask = tl.where(input_split_sizes_cumsum <= pid, 1, 0) input_chunk_idx = tl.sum(input_split_sizes_mask) input_split_sizes_presum = tl.sum(input_split_sizes * input_split_sizes_mask) @@ -578,6 +585,11 @@ def _make_chunk_sort_map_kernel( ).to(tl.int32) output_pre_split_sizes = tl.where(load_split_offset < output_chunk_idx, output_split_sizes, 0) dst_row = tl.sum(output_pre_split_sizes) + in_chunk_offset + + # For tokens beyond the valid range (pid >= total_valid_tokens), + # use identity mapping to avoid corrupting valid data + dst_row = tl.where(pid < total_valid_tokens, dst_row, pid) + tl.store(dst_rows_ptr + pid, dst_row) diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 3fd086e257..1c0bc52b88 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -143,6 +143,9 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(RHTAmaxCalculationInitializeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(RHTAmaxCalculationHandler); +// Inspect +XLA_FFI_DECLARE_HANDLER_SYMBOL(InspectHandler); + // Cudnn helpers XLA_FFI_DECLARE_HANDLER_SYMBOL(CudnnHandleInitHandler); diff --git 
a/transformer_engine/jax/csrc/extensions/amax.cpp b/transformer_engine/jax/csrc/extensions/amax.cpp index 5ffccaffb4..e18ee99d81 100644 --- a/transformer_engine/jax/csrc/extensions/amax.cpp +++ b/transformer_engine/jax/csrc/extensions/amax.cpp @@ -5,6 +5,7 @@ ************************************************************************/ #include +#include <fstream> #include #include "../extensions.h" @@ -96,5 +97,47 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .Attr("produce_regular_amax") // produce_regular_amax .Attr("flatten_axis")); // flatten_axis +Error_Type InspectFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type output_buf) { + NVTE_CHECK(input_buf.untyped_data() != nullptr, "Input must be provided for inspect operation"); + NVTE_CHECK(output_buf->untyped_data() != nullptr, + "Output must be provided for inspect operation"); + NVTE_CHECK(input_buf.untyped_data() == output_buf->untyped_data(), + "Input and output must point to the same buffer for inspect operation"); + + std::vector<char> input_data(input_buf.size_bytes()); + cudaMemcpyAsync(input_data.data(), input_buf.untyped_data(), input_buf.size_bytes(), + cudaMemcpyDeviceToHost, stream); + cudaStreamSynchronize(stream); + + int device; + cudaGetDevice(&device); + + std::string filename = "my_tensor_gpu" + std::to_string(device) + ".bin"; + std::ofstream file(filename, std::ios::binary); + if (file.is_open()) { + file.write(reinterpret_cast<const char *>(input_data.data()), input_data.size()); + file.close(); + } + printf("Tensor data written to %s (shape: [", filename.c_str()); + for (size_t i = 0; i < input_buf.dimensions().size(); ++i) { + printf("%ld", static_cast<long>(input_buf.dimensions()[i])); + if (i < input_buf.dimensions().size() - 1) { + printf(", "); + } + } + printf("], dtype: %d)\n", static_cast<int>(input_buf.element_type())); + + // TODO: make a metadata file with tensor shape and dtype? 
+ + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(InspectHandler, InspectFFI, + FFI::Bind() + .Ctx<FFI_Stream_Type>() // stream + .Arg<Buffer_Type>() // input + .Ret<Buffer_Type>() // output +); + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index a5986404c9..5a8ee18f09 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -81,6 +81,9 @@ pybind11::dict Registrations() { pybind11::arg("initialize") = EncapsulateFFI(RHTAmaxCalculationInitializeHandler), pybind11::arg("execute") = EncapsulateFFI(RHTAmaxCalculationHandler)); + dict["te_inspect_ffi"] = + pybind11::dict(pybind11::arg("execute") = EncapsulateFFI(InspectHandler)); + return dict; } diff --git a/transformer_engine/jax/inspect.py b/transformer_engine/jax/inspect.py new file mode 100644 index 0000000000..61bbaf8bb0 --- /dev/null +++ b/transformer_engine/jax/inspect.py @@ -0,0 +1,128 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""JAX array inspection utilities.""" + +from functools import partial + +import jax +import jax.numpy as jnp +from jax import ffi + +from .cpp_extensions.base import BasePrimitive, register_primitive + +__all__ = ["inspect_array", "load_array_dump"] + + +class InspectPrimitive(BasePrimitive): + """ + No-op used for inspect array values. 
+ """ + + name = "te_inspect_ffi" + multiple_results = False + impl_static_args = () + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + x_aval, + ): + """ + inspect abstract + """ + return x_aval + + @staticmethod + def lowering( + ctx, + x, + ): + """ + inspect lowering rules + """ + + return ffi.ffi_lowering( + InspectPrimitive.name, + operand_output_aliases={0: 0}, # donate input buffer to output buffer + )( + ctx, + x, + ) + + @staticmethod + def impl( + x, + ): + """ + inspect implementation + """ + assert InspectPrimitive.inner_primitive is not None + (x) = InspectPrimitive.inner_primitive.bind( + x, + ) + return x + + +register_primitive(InspectPrimitive) + + +@partial(jax.custom_vjp, nondiff_argnums=()) +def _inspect( + x, +): + """ """ + output, _ = _inspect_fwd_rule( + x, + ) + return output + + +def _inspect_fwd_rule( + x, +): + """""" + ctx = () + x = InspectPrimitive.outer_primitive.bind(x) + return x, ctx + + +def _inspect_bwd_rule( + ctx, + grad, +): + """""" + del ctx + return (grad,) + + +_inspect.defvjp(_inspect_fwd_rule, _inspect_bwd_rule) + + +def inspect_array(x: jnp.ndarray, name: str) -> jnp.ndarray: + """Utility function to inspect JAX arrays by printing their name, shape, dtype, and statistics. + + Args: + x (jnp.ndarray): The JAX array to inspect. + name (str): The name of the array for identification in the output. + """ + # TODO: Handle the name of the tensor in the primitive and output files + return _inspect(x) + + +def load_array_dump(filename: str, shape: tuple, dtype: jnp.dtype) -> jnp.ndarray: + """Utility function to load a JAX array from a dumped binary file. + + Args: + filename (str): The path to the binary file containing the array data. + shape (tuple): The shape of the array to be loaded. + dtype (jnp.dtype): The data type of the array to be loaded. + + Returns: + jnp.ndarray: The loaded JAX array. 
+ """ + with open(filename, "rb") as f: + data = f.read() + array = jnp.frombuffer(data, dtype=dtype).reshape(shape) + return array diff --git a/transformer_engine/jax/permutation.py b/transformer_engine/jax/permutation.py index 438511fa55..9b6d0f75e0 100644 --- a/transformer_engine/jax/permutation.py +++ b/transformer_engine/jax/permutation.py @@ -581,7 +581,7 @@ def sort_chunks_by_index( return _sort_chunks_by_index(inp, split_sizes, sorted_indices) -@partial(jax.custom_vjp, nondiff_argnums=(1, 2)) +@jax.custom_vjp def _sort_chunks_by_index( inp: jnp.ndarray, split_sizes: jnp.ndarray, @@ -596,7 +596,7 @@ def _sort_chunks_by_index_fwd_rule( inp: jnp.ndarray, split_sizes: jnp.ndarray, sorted_indices: jnp.ndarray, -) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], Tuple[jnp.ndarray, int, int]]: +) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, int, int]]: """Forward pass rule for sort_chunks_by_index.""" # Validate input dimensions assert inp.ndim in [2, 3], f"inp must be 2D or 3D, got {inp.ndim}D" @@ -618,18 +618,17 @@ def _sort_chunks_by_index_fwd_rule( ) # Return (primals, residuals) - residuals = (row_id_map, num_tokens, hidden_size) + # Include split_sizes and sorted_indices in residuals since we removed nondiff_argnums + residuals = (row_id_map, split_sizes, sorted_indices, num_tokens, hidden_size) return (output, row_id_map), residuals def _sort_chunks_by_index_bwd_rule( - _split_sizes: jnp.ndarray, - _sorted_indices: jnp.ndarray, - residuals: Tuple[jnp.ndarray, int, int], + residuals: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, int, int], g: Tuple[jnp.ndarray, jnp.ndarray], -) -> Tuple[jnp.ndarray]: +) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: """Backward pass rule for sort_chunks_by_index.""" - row_id_map, num_tokens, hidden_size = residuals + row_id_map, split_sizes, sorted_indices, num_tokens, hidden_size = residuals output_grad, _ = g # Backward: reverse the sort @@ -642,7 +641,13 @@ def 
_sort_chunks_by_index_bwd_rule( is_forward=False, ) - return (inp_grad,) + # Return gradients for all inputs: (inp, split_sizes, sorted_indices) + # split_sizes and sorted_indices are integer index arrays, so their gradients are + # zeros matching each array's own shape and dtype + split_sizes_grad = jnp.zeros_like(split_sizes, dtype=split_sizes.dtype) + sorted_indices_grad = jnp.zeros_like(sorted_indices, dtype=sorted_indices.dtype) + + return (inp_grad, split_sizes_grad, sorted_indices_grad) _sort_chunks_by_index.defvjp(_sort_chunks_by_index_fwd_rule, _sort_chunks_by_index_bwd_rule)