-
Notifications
You must be signed in to change notification settings - Fork 638
Initial commit to pass scale as Tensor for multi_tensor_scale op #2594
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
vasunvidia
wants to merge
4
commits into
NVIDIA:main
Choose a base branch
from
vasunvidia:vrengasamy/multi_tensor_scale_cg
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
87448f2
Initial commit to pass scale as Tensor for multi_tensor_scale op
vasunvidia 0d51679
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] fc0fe49
Enable capturable mode for optimizer if store_param_remainders is pas…
vasunvidia cfd4370
Revert "Enable capturable mode for optimizer if store_param_remainder…
vasunvidia File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -102,6 +102,75 @@ struct ScaleFunctor { | |
| } | ||
| }; | ||
|
|
||
| template <typename in_t, typename out_t> | ||
| struct ScalePtrFunctor { | ||
| __device__ __forceinline__ void operator()(int chunk_size, volatile int *noop_gmem, | ||
| TensorListMetadata<2> &tl, // NOLINT(*) | ||
| float *scale_ptr) { | ||
| // I'd like this kernel to propagate infs/nans. | ||
| // if(*noop_gmem == 1) | ||
| // return; | ||
| float scale = *scale_ptr; | ||
| int tensor_loc = tl.block_to_tensor[blockIdx.x]; | ||
| int chunk_idx = tl.block_to_chunk[blockIdx.x]; | ||
| int n = tl.sizes[tensor_loc]; | ||
|
|
||
| in_t *in = reinterpret_cast<in_t *>(tl.addresses[0][tensor_loc]); | ||
| in += chunk_idx * chunk_size; | ||
|
|
||
| out_t *out = reinterpret_cast<out_t *>(tl.addresses[1][tensor_loc]); | ||
| out += chunk_idx * chunk_size; | ||
|
|
||
| n -= chunk_idx * chunk_size; | ||
|
|
||
| bool finite = true; | ||
| in_t r_in[ILP]; | ||
| out_t r_out[ILP]; | ||
|
|
||
| // to make things simple, we put aligned case in a different code path | ||
| if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(in) && is_aligned(out)) { | ||
| for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; | ||
| i_start += blockDim.x) { | ||
| // load | ||
| load_store(r_in, in, 0, i_start); | ||
| #pragma unroll | ||
| for (int ii = 0; ii < ILP; ii++) { | ||
| r_out[ii] = static_cast<float>(r_in[ii]) * scale; | ||
| finite = finite && isfinite(static_cast<float>(r_in[ii])); | ||
| } | ||
| // store | ||
| load_store(out, r_out, i_start, 0); | ||
| } | ||
| } else { | ||
| // Non-divergent exit condition for __syncthreads, not necessary here | ||
| for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) { | ||
| #pragma unroll | ||
| for (int ii = 0; ii < ILP; ii++) { | ||
| r_in[ii] = 0.f; | ||
| int i = i_start + threadIdx.x + ii * blockDim.x; | ||
| if (i < n && i < chunk_size) r_in[ii] = in[i]; | ||
| } | ||
| // note for clarification to future michael: | ||
| // From a pure memory dependency perspective, there's likely no point unrolling | ||
| // the write loop, since writes just fire off once their LDGs arrive. | ||
| // Put another way, the STGs are dependent on the LDGs, but not on each other. | ||
| // There is still compute ILP benefit from unrolling the loop though. | ||
| #pragma unroll | ||
| for (int ii = 0; ii < ILP; ii++) { | ||
| r_out[ii] = static_cast<float>(r_in[ii]) * scale; | ||
| finite = finite && isfinite(static_cast<float>(r_in[ii])); | ||
| } | ||
| #pragma unroll | ||
| for (int ii = 0; ii < ILP; ii++) { | ||
| int i = i_start + threadIdx.x + ii * blockDim.x; | ||
| if (i < n && i < chunk_size) out[i] = r_out[ii]; | ||
| } | ||
| } | ||
| } | ||
| if (!finite) *noop_gmem = 1; // Blindly fire off a write. These will race but that's ok. | ||
| } | ||
| }; | ||
|
|
||
| void multi_tensor_scale_cuda(int chunk_size, Tensor noop_flag, | ||
| std::vector<std::vector<Tensor *>> tensor_lists, float scale, | ||
| cudaStream_t stream) { | ||
|
|
@@ -114,6 +183,18 @@ void multi_tensor_scale_cuda(int chunk_size, Tensor noop_flag, | |
| NVTE_CHECK_CUDA(cudaGetLastError()); | ||
| } | ||
|
|
||
| void multi_tensor_scale_tensor_cuda(int chunk_size, Tensor noop_flag, | ||
| std::vector<std::vector<Tensor *>> tensor_lists, float *scale, | ||
| cudaStream_t stream) { | ||
| TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( | ||
| tensor_lists[0][0]->dtype(), p_in_type, | ||
| TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( | ||
| tensor_lists[1][0]->dtype(), g_in_type, | ||
| multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, | ||
| ScalePtrFunctor<p_in_type, g_in_type>(), stream, scale);)) | ||
| NVTE_CHECK_CUDA(cudaGetLastError()); | ||
| } | ||
|
|
||
| } // namespace multi_tensor_scale | ||
| } // namespace transformer_engine | ||
|
|
||
|
|
@@ -127,3 +208,17 @@ void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETens | |
| chunk_size, *convertNVTETensorCheck(noop_flag), | ||
| convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), scale, stream); | ||
| } | ||
|
|
||
| void nvte_multi_tensor_scale_tensor_cuda(int chunk_size, NVTETensor noop_flag, | ||
| NVTETensor **tensor_lists, const size_t num_tensor_lists, | ||
| const size_t num_tensors_per_list, NVTETensor scale, | ||
| cudaStream_t stream) { | ||
| NVTE_API_CALL(nvte_multi_tensor_scale_tensor_cuda); | ||
| using namespace transformer_engine; | ||
|
|
||
| Tensor *scale_tensor = convertNVTETensorCheck(scale); | ||
| multi_tensor_scale::multi_tensor_scale_tensor_cuda( | ||
| chunk_size, *convertNVTETensorCheck(noop_flag), | ||
| convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), | ||
| reinterpret_cast<float *>(scale_tensor->data.dptr), stream); | ||
|
Comment on lines
+219
to
+223
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. validate
|
||
| } | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -18,4 +18,17 @@ void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag, | |||
| num_tensors, scale, at::cuda::getCurrentCUDAStream()); | ||||
| } | ||||
|
|
||||
| void multi_tensor_scale_tensor_cuda(int chunk_size, at::Tensor noop_flag, | ||||
| std::vector<std::vector<at::Tensor>> tensor_lists, | ||||
| at::Tensor scale) { | ||||
| auto noop_flag_cu = makeTransformerEngineTensor(noop_flag); | ||||
| auto scale_cu = makeTransformerEngineTensor(scale); | ||||
| auto [_, __, tensor_lists_ptr, num_lists, num_tensors] = | ||||
| makeTransformerEngineTensorList(tensor_lists); | ||||
| std::cout << "multi_tensor_scale_cuda TENSOR\n"; | ||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. debug print statement left in production code
Suggested change
|
||||
| nvte_multi_tensor_scale_tensor_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), | ||||
| num_lists, num_tensors, scale_cu.data(), | ||||
| at::cuda::getCurrentCUDAStream()); | ||||
| } | ||||
|
|
||||
| } // namespace transformer_engine::pytorch | ||||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
all threads in a block read from the same device memory location without synchronization - could cause redundant memory traffic but functionally correct. consider caching in shared memory or using
__ldg()for read-only cache optimizationNote: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!