
Conversation

@vasunvidia (Collaborator)

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@vasunvidia force-pushed the vrengasamy/multi_tensor_scale_cg branch from b2a5ae5 to 4081afc on February 11, 2026 at 19:17
@vasunvidia marked this pull request as ready for review on February 11, 2026 at 19:20
vasunvidia and others added 4 commits February 11, 2026 11:20
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
for more information, see https://pre-commit.ci

Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
…sed but not actually enabled

Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
…s is passed but not actually enabled"

This reverts commit 74a9bcc.

Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
@vasunvidia force-pushed the vrengasamy/multi_tensor_scale_cg branch from 4081afc to cfd4370 on February 11, 2026 at 19:20

greptile-apps bot commented Feb 11, 2026

Greptile Overview

Greptile Summary

Added a tensor-based variant of the multi_tensor_scale operation (multi_tensor_scale_tensor) that accepts the scale factor as a device tensor instead of a host scalar.

Key changes:

  • Implemented a new ScalePtrFunctor CUDA functor that dereferences the scale from device memory (see the illustrative sketch after this list)
  • Added the C API function nvte_multi_tensor_scale_tensor_cuda and the corresponding PyTorch bindings
  • Exposed the new function through the Python API in the optimizers module
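
For orientation, here is a minimal standalone sketch of the pattern the new functor uses: the scale is dereferenced from a device pointer inside the kernel rather than passed by value from the host. This is not the TransformerEngine multi_tensor_apply machinery; the kernel name and parameters below are illustrative only.

#include <cuda_runtime.h>

__global__ void scale_from_ptr_kernel(const float* __restrict__ in,
                                      float* __restrict__ out,
                                      const float* __restrict__ scale_ptr,
                                      int n) {
  // Each thread reads the single-element scale tensor from global memory,
  // so the value can be produced on the device (e.g. by another kernel)
  // without a device-to-host copy before the launch.
  const float scale = *scale_ptr;
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < n) {
    out[idx] = in[idx] * scale;
  }
}

Reading the scale through a device pointer avoids baking a host-side value into the launch; the branch name (multi_tensor_scale_cg) suggests CUDA-graph capture is the motivating case, though the PR description does not say so.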

Issues found:

  • A debug print statement left in scale.cpp:28 must be removed
  • Missing validation of the scale tensor's dtype (should be kFloat32) and size (should be a single element) in scale.cu:219
  • Minor performance consideration: the scale value is read from device memory by every thread; __ldg() or shared memory could be used as an optimization

Confidence Score: 3/5

  • Safe to merge after removing the debug statement and adding the validation checks
  • The implementation follows existing patterns and mirrors the scalar version correctly, but it contains a debug print statement that must be removed and lacks input validation for the scale tensor, which could cause runtime errors if invalid tensors are passed
  • transformer_engine/common/multi_tensor/scale.cu and transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp need attention for the validation and debug-statement issues

Important Files Changed

Filename | Overview
transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp | Added the multi_tensor_scale_tensor_cuda function to accept the scale as a tensor; contains a debug print statement that should be removed
transformer_engine/common/multi_tensor/scale.cu | Implemented ScalePtrFunctor and multi_tensor_scale_tensor_cuda for tensor-based scaling; missing validation of the scale tensor's dtype and size

Sequence Diagram

sequenceDiagram
    participant Python as Python Code
    participant PyBind as PyTorch Binding
    participant Wrapper as C++ Wrapper
    participant CUDA as CUDA Kernel
    
    Python->>PyBind: multi_tensor_scale_tensor(chunk_size, noop_flag, tensor_lists, scale)
    PyBind->>Wrapper: multi_tensor_scale_tensor_cuda(...)
    Note over Wrapper: Convert PyTorch tensors<br/>to TE tensors
    Wrapper->>CUDA: nvte_multi_tensor_scale_tensor_cuda(scale as NVTETensor)
    Note over CUDA: Extract scale_tensor->data.dptr<br/>Cast to float*
    CUDA->>CUDA: multi_tensor_scale_tensor_cuda(scale as float*)
    Note over CUDA: Launch ScalePtrFunctor kernel<br/>with scale pointer
    CUDA->>CUDA: ScalePtrFunctor reads *scale_ptr<br/>Applies scaling to tensors
    CUDA-->>Python: Return scaled tensors

@greptile-apps bot left a comment


6 files reviewed, 3 comments


In transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp:

auto scale_cu = makeTransformerEngineTensor(scale);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
    makeTransformerEngineTensorList(tensor_lists);
std::cout << "multi_tensor_scale_cuda TENSOR\n";

Debug print statement left in production code.

Suggested change (remove this line):

std::cout << "multi_tensor_scale_cuda TENSOR\n";

Comment on lines +219 to +223 of transformer_engine/common/multi_tensor/scale.cu
Tensor *scale_tensor = convertNVTETensorCheck(scale);
multi_tensor_scale::multi_tensor_scale_tensor_cuda(
    chunk_size, *convertNVTETensorCheck(noop_flag),
    convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
    reinterpret_cast<float *>(scale_tensor->data.dptr), stream);

Validate the scale tensor's dtype and size before casting. The code should check that:

  • dtype is kFloat32
  • numel() == 1 (a single-element tensor)

Suggested change:

  Tensor *scale_tensor = convertNVTETensorCheck(scale);
  NVTE_CHECK(scale_tensor->dtype() == DType::kFloat32,
             "scale tensor must be float32, got ", to_string(scale_tensor->dtype()));
  NVTE_CHECK(scale_tensor->numel() == 1,
             "scale tensor must have exactly 1 element, got ", scale_tensor->numel());
  multi_tensor_scale::multi_tensor_scale_tensor_cuda(

In ScalePtrFunctor (transformer_engine/common/multi_tensor/scale.cu):

// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
float scale = *scale_ptr;

All threads in a block read from the same device memory location without synchronization; this is functionally correct but could cause redundant memory traffic. Consider caching the value in shared memory or using __ldg() for a read-only cache optimization (a sketch follows below).

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
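
For reference, a minimal sketch of what the suggested __ldg() read could look like, reusing the illustrative kernel from the summary above rather than the actual TE multi_tensor_apply functor:

__global__ void scale_from_ptr_ldg_kernel(const float* __restrict__ in,
                                          float* __restrict__ out,
                                          const float* __restrict__ scale_ptr,
                                          int n) {
  // __ldg() routes the load through the read-only data cache. For a
  // `const float* __restrict__` pointer recent compilers usually emit the
  // same cached load anyway, so the practical benefit may be small.
  const float scale = __ldg(scale_ptr);
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < n) {
    out[idx] = in[idx] * scale;
  }
}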
