Initial commit to pass scale as Tensor for multi_tensor_scale op #2594
Conversation
Force-pushed from b2a5ae5 to 4081afc.
Commit messages (truncated in the page view):
- Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
- for more information, see https://pre-commit.ci. Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
- …sed but not actually enabled. Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
- …s is passed but not actually enabled". This reverts commit 74a9bcc. Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Force-pushed from 4081afc to cfd4370.
Greptile Overview

Greptile Summary: added a tensor-based variant of multi_tensor_scale. Key changes:
Issues found:
Confidence Score: 3/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant Python as Python Code
    participant PyBind as PyTorch Binding
    participant Wrapper as C++ Wrapper
    participant CUDA as CUDA Kernel
    Python->>PyBind: multi_tensor_scale_tensor(chunk_size, noop_flag, tensor_lists, scale)
    PyBind->>Wrapper: multi_tensor_scale_tensor_cuda(...)
    Note over Wrapper: Convert PyTorch tensors<br/>to TE tensors
    Wrapper->>CUDA: nvte_multi_tensor_scale_tensor_cuda(scale as NVTETensor)
    Note over CUDA: Extract scale_tensor->data.dptr<br/>Cast to float*
    CUDA->>CUDA: multi_tensor_scale_tensor_cuda(scale as float*)
    Note over CUDA: Launch ScalePtrFunctor kernel<br/>with scale pointer
    CUDA->>CUDA: ScalePtrFunctor reads *scale_ptr<br/>Applies scaling to tensors
    CUDA-->>Python: Return scaled tensors
```
6 files reviewed, 3 comments
```cpp
auto scale_cu = makeTransformerEngineTensor(scale);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
    makeTransformerEngineTensorList(tensor_lists);
std::cout << "multi_tensor_scale_cuda TENSOR\n";
```
Debug print statement left in production code:

```cpp
std::cout << "multi_tensor_scale_cuda TENSOR\n";
```
```cpp
Tensor *scale_tensor = convertNVTETensorCheck(scale);
multi_tensor_scale::multi_tensor_scale_tensor_cuda(
    chunk_size, *convertNVTETensorCheck(noop_flag),
    convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
    reinterpret_cast<float *>(scale_tensor->data.dptr), stream);
```
Validate the scale tensor dtype and size before casting. Should check that:
- dtype is kFloat32
- numel() == 1 (single-element tensor)

```cpp
Tensor *scale_tensor = convertNVTETensorCheck(scale);
NVTE_CHECK(scale_tensor->dtype() == DType::kFloat32,
           "scale tensor must be float32, got ", to_string(scale_tensor->dtype()));
NVTE_CHECK(scale_tensor->numel() == 1,
           "scale tensor must have exactly 1 element, got ", scale_tensor->numel());
multi_tensor_scale::multi_tensor_scale_tensor_cuda(
```
```cpp
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
//   return;
float scale = *scale_ptr;
```
All threads in a block read from the same device memory location without synchronization. This is functionally correct but can cause redundant memory traffic; consider caching the value in shared memory, or using __ldg() for the read-only cache.
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: