[JAX] TE Permutation integration to Maxtext #2672
base: main
Changes from all commits: 6d33449, 3e121b5, a76763a, 1059c2b, 738cce9, 5d8d06c
```diff
@@ -5,6 +5,7 @@
  ************************************************************************/
 #include <cuda_runtime.h>

+#include <fstream>
 #include <iostream>

 #include "../extensions.h"
```
```diff
@@ -96,5 +97,47 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(
         .Attr<bool>("produce_regular_amax")  // produce_regular_amax
         .Attr<int64_t>("flatten_axis"));     // flatten_axis

+Error_Type InspectFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type output_buf) {
+  NVTE_CHECK(input_buf.untyped_data() != nullptr, "Input must be provided for inspect operation");
+  NVTE_CHECK(output_buf->untyped_data() != nullptr,
+             "Output must be provided for inspect operation");
+  NVTE_CHECK(input_buf.untyped_data() == output_buf->untyped_data(),
+             "Input and output must point to the same buffer for inspect operation");
+
+  std::vector<uint8_t> input_data(input_buf.size_bytes());
+  cudaMemcpyAsync(input_data.data(), input_buf.untyped_data(), input_buf.size_bytes(),
+                  cudaMemcpyDeviceToHost, stream);
+  cudaStreamSynchronize(stream);
+
+  int device;
+  cudaGetDevice(&device);
+
+  std::string filename = "my_tensor_gpu" + std::to_string(device) + ".bin";
+  std::ofstream file(filename, std::ios::binary);
+  if (file.is_open()) {
+    file.write(reinterpret_cast<const char *>(input_data.data()), input_data.size());
+    file.close();
+  }
```
Comment on lines +116 to +120

Contributor: No error handling if file fails to open - silently continues without writing data. (A suggested change was attached to this comment.)
```diff
+
+  printf("Tensor data written to %s (shape: [", filename.c_str());
+  for (size_t i = 0; i < input_buf.dimensions().size(); ++i) {
+    printf("%ld", static_cast<long>(input_buf.dimensions()[i]));
+    if (i < input_buf.dimensions().size() - 1) {
+      printf(", ");
+    }
+  }
+  printf("], dtype: %d)\n", static_cast<int>(input_buf.element_type()));
+
+  // TODO: make a metadata file with tensor shape and dtype?
+
+  return ffi_with_cuda_error_check();
+}
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(InspectHandler, InspectFFI,
+                              FFI::Bind()
+                                  .Ctx<FFI_Stream_Type>()  // stream
+                                  .Arg<Buffer_Type>()      // input
+                                  .Ret<Buffer_Type>()      // output
+);
+
 }  // namespace jax
 }  // namespace transformer_engine
```
@@ -0,0 +1,128 @@
```python
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX array inspection utilities."""

from functools import partial

import jax
import jax.numpy as jnp
from jax import ffi

from .cpp_extensions.base import BasePrimitive, register_primitive

__all__ = ["inspect_array", "load_array_dump"]


class InspectPrimitive(BasePrimitive):
    """
    No-op used for inspect array values.
    """

    name = "te_inspect_ffi"
    multiple_results = False
    impl_static_args = ()
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
    ):
        """
        inspect abstract
        """
        return x_aval

    @staticmethod
    def lowering(
        ctx,
        x,
    ):
        """
        inspect lowering rules
        """
        return ffi.ffi_lowering(
            InspectPrimitive.name,
            operand_output_aliases={0: 0},  # donate input buffer to output buffer
        )(
            ctx,
            x,
        )

    @staticmethod
    def impl(
        x,
    ):
        """
        inspect implementation
        """
        assert InspectPrimitive.inner_primitive is not None
        (x) = InspectPrimitive.inner_primitive.bind(
            x,
        )
        return x


register_primitive(InspectPrimitive)


@partial(jax.custom_vjp, nondiff_argnums=())
def _inspect(
    x,
):
    """ """
    output, _ = _inspect_fwd_rule(
        x,
    )
    return output


def _inspect_fwd_rule(
    x,
):
    """"""
    ctx = ()
    x = InspectPrimitive.outer_primitive.bind(x)
    return x, ctx


def _inspect_bwd_rule(
    ctx,
    grad,
):
    """"""
    del ctx
    return (grad,)


_inspect.defvjp(_inspect_fwd_rule, _inspect_bwd_rule)


def inspect_array(x: jnp.ndarray, name: str) -> jnp.ndarray:
| """Utility function to inspect JAX arrays by printing their name, shape, dtype, and statistics. | ||||||
|
|
||||||
| Args: | ||||||
| x (jnp.ndarray): The JAX array to inspect. | ||||||
| name (str): The name of the array for identification in the output. | ||||||
| """ | ||||||
| # TODO: Handle the name of the tensor in the primitive and output files | ||||||
| return _inspect(x) | ||||||
|
|
||||||
|
|
||||||
| def load_array_dump(filename: str, shape: tuple, dtype: jnp.dtype) -> jnp.ndarray: | ||||||
| """Utility function to load a JAX array from a dumped binary file. | ||||||
|
|
||||||
| Args: | ||||||
| filename (str): The path to the binary file containing the array data. | ||||||
| shape (tuple): The shape of the array to be loaded. | ||||||
| dtype (jnp.dtype): The data type of the array to be loaded. | ||||||
|
|
||||||
| Returns: | ||||||
| jnp.ndarray: The loaded JAX array. | ||||||
| """ | ||||||
| with open(filename, "rb") as f: | ||||||
| data = f.read() | ||||||
| array = jnp.frombuffer(data, dtype=dtype).reshape(shape) | ||||||
| return array | ||||||
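For reference, a minimal usage sketch of the two helpers above (not part of the PR). The `transformer_engine.jax` import path is an assumption, since the new file's module location is not shown in this view; the dump filename follows the `my_tensor_gpu<device>.bin` pattern hard-coded in the C++ handler.

```python
# Hypothetical usage sketch; the module path below is assumed, not confirmed by this diff.
import jax
import jax.numpy as jnp

from transformer_engine.jax import inspect_array, load_array_dump  # assumed export location

@jax.jit
def forward(x):
    h = x * 2.0
    # No-op in the traced computation; the FFI handler copies the buffer to host
    # and writes it to "my_tensor_gpu<device>.bin" (see the C++ handler above).
    h = inspect_array(h, name="hidden")
    return h.sum()

x = jnp.ones((4, 8), dtype=jnp.float32)
forward(x)

# The dump is raw bytes with no metadata, so shape and dtype must be supplied
# when reading it back (device 0 assumed here).
hidden = load_array_dump("my_tensor_gpu0.bin", shape=(4, 8), dtype=jnp.float32)
```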
```diff
@@ -581,7 +581,7 @@ def sort_chunks_by_index(
     return _sort_chunks_by_index(inp, split_sizes, sorted_indices)


-@partial(jax.custom_vjp, nondiff_argnums=(1, 2))
+@jax.custom_vjp
 def _sort_chunks_by_index(
     inp: jnp.ndarray,
     split_sizes: jnp.ndarray,
```
```diff
@@ -596,7 +596,7 @@ def _sort_chunks_by_index_fwd_rule(
     inp: jnp.ndarray,
     split_sizes: jnp.ndarray,
     sorted_indices: jnp.ndarray,
-) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], Tuple[jnp.ndarray, int, int]]:
+) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, int, int]]:
     """Forward pass rule for sort_chunks_by_index."""
     # Validate input dimensions
     assert inp.ndim in [2, 3], f"inp must be 2D or 3D, got {inp.ndim}D"
```
```diff
@@ -618,18 +618,17 @@ def _sort_chunks_by_index_fwd_rule(
     )

     # Return (primals, residuals)
-    residuals = (row_id_map, num_tokens, hidden_size)
+    # Include split_sizes and sorted_indices in residuals since we removed nondiff_argnums
+    residuals = (row_id_map, split_sizes, sorted_indices, num_tokens, hidden_size)
     return (output, row_id_map), residuals


 def _sort_chunks_by_index_bwd_rule(
-    _split_sizes: jnp.ndarray,
-    _sorted_indices: jnp.ndarray,
-    residuals: Tuple[jnp.ndarray, int, int],
+    residuals: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, int, int],
     g: Tuple[jnp.ndarray, jnp.ndarray],
-) -> Tuple[jnp.ndarray]:
+) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
     """Backward pass rule for sort_chunks_by_index."""
-    row_id_map, num_tokens, hidden_size = residuals
+    row_id_map, split_sizes, sorted_indices, num_tokens, hidden_size = residuals
     output_grad, _ = g

     # Backward: reverse the sort
```
```diff
@@ -642,7 +641,13 @@ def _sort_chunks_by_index_bwd_rule(
         is_forward=False,
     )

-    return (inp_grad,)
+    # Return gradients for all inputs: (inp, split_sizes, sorted_indices)
```
Collaborator: I'm not following why the split_sizes and sorted_indices can't be non-differentiable args, can you let me know the reasoning. Thanks!
```diff
+    # split_sizes and sorted_indices are integer arrays, so their gradients are zeros
+    # with matching dtype (use float32 as a safe default for index arrays)
+    split_sizes_grad = jnp.zeros_like(split_sizes, dtype=split_sizes.dtype)
```
Contributor: Comment says "use float32 as safe default" but code preserves original dtype (likely int32/int64) - comment is misleading.
```diff
+    sorted_indices_grad = jnp.zeros_like(sorted_indices, dtype=sorted_indices.dtype)
+
+    return (inp_grad, split_sizes_grad, sorted_indices_grad)


 _sort_chunks_by_index.defvjp(_sort_chunks_by_index_fwd_rule, _sort_chunks_by_index_bwd_rule)
```
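For context on the discussion above, a self-contained sketch (hypothetical function, not the PR's code) of the mechanics being debated: with `@jax.custom_vjp` and no `nondiff_argnums`, the bwd rule must return one cotangent per primal argument, so integer index arguments are given zero arrays, mirroring `split_sizes_grad` and `sorted_indices_grad` in the diff. This only illustrates the calling convention; it does not address whether the arguments could instead be treated as non-differentiable.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def permute_rows(x, perm):
    # Gather rows of x according to the integer index array perm.
    return x[perm]

def permute_rows_fwd(x, perm):
    return x[perm], (perm,)

def permute_rows_bwd(residuals, g):
    (perm,) = residuals
    # Scatter the incoming cotangent back to the original row order.
    x_grad = jnp.zeros_like(g).at[perm].add(g)
    # One cotangent per primal input, in order (x, perm); perm is an index array
    # with no meaningful gradient, so a zero array of matching shape is returned.
    return (x_grad, jnp.zeros_like(perm))

permute_rows.defvjp(permute_rows_fwd, permute_rows_bwd)

x = jnp.arange(12.0).reshape(4, 3)
perm = jnp.array([2, 0, 3, 1])
x_grad = jax.grad(lambda a: permute_rows(a, perm).sum())(x)
```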
Do we need to copy the tokens/overallocated space beyond the valid range to the output? Or can we skip those and leave the output uninitialized, since they are beyond the valid range and won't be used downstream?

Maybe this could be achieved by using a `mask` arg to load/store that is dependent on `tl.sum(input_split_sizes)`? https://triton-lang.org/main/python-api/generated/triton.language.store.html#triton.language.store
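A rough sketch of the masking idea described above, with hypothetical kernel and argument names (this is not the kernel under review): the valid-row count comes from `tl.sum` over the split sizes, and rows beyond it are neither loaded nor stored.

```python
import triton
import triton.language as tl

@triton.jit
def copy_valid_rows_kernel(inp_ptr, out_ptr, split_sizes_ptr, n_splits, hidden,
                           SPLITS_PAD: tl.constexpr, BLOCK: tl.constexpr):
    # One program instance per (possibly overallocated) output row.
    row = tl.program_id(0)

    # Total number of valid rows = sum of the per-chunk split sizes.
    # SPLITS_PAD is a power of two >= n_splits (tl.arange requires a power of two).
    offs = tl.arange(0, SPLITS_PAD)
    sizes = tl.load(split_sizes_ptr + offs, mask=offs < n_splits, other=0)
    num_valid = tl.sum(sizes, axis=0)

    # BLOCK is a power of two >= hidden. Rows at or beyond num_valid are the
    # overallocated region: the mask skips both the load and the store for them,
    # leaving that part of the output untouched.
    cols = tl.arange(0, BLOCK)
    mask = (cols < hidden) & (row < num_valid)
    vals = tl.load(inp_ptr + row * hidden + cols, mask=mask, other=0.0)
    tl.store(out_ptr + row * hidden + cols, vals, mask=mask)
```

With this approach the overallocated rows of the output stay uninitialized, which is only safe if downstream code never reads past the valid range.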