Conversation

Member

@ksivaman ksivaman commented Feb 6, 2026

Description

Extracts the Python pieces of the GroupedTensor infrastructure from #2600. Since this portion is mainly focused on creating the weights as a single GroupedTensor and exposing them as multiple QuantizedTensors for PyTorch, it does not need to be graph capturable.
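As a rough illustration of that idea (plain PyTorch, not TE code; the shapes below are made up), allocating one contiguous buffer and exposing per-tensor views over it looks like this:

import torch

# Hypothetical per-expert weight shapes; in GroupedLinear these come from the module config.
shapes = [(4, 8), (2, 8), (3, 8)]
total_elements = sum(rows * cols for rows, cols in shapes)

# One contiguous backing buffer for all weights.
backing = torch.empty(total_elements, dtype=torch.bfloat16)

# Slice the buffer into per-tensor views that alias the same storage.
views, offset = [], 0
for rows, cols in shapes:
    n = rows * cols
    views.append(backing[offset : offset + n].view(rows, cols))
    offset += n

# Every view shares storage with the backing buffer, so the "weights" stay contiguous.
assert all(v.untyped_storage().data_ptr() == backing.untyped_storage().data_ptr() for v in views)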

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Expose a Python GroupedTensor class.
  • Integrate GroupedTensor into GroupedLinear such that the parameters are contiguous.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Contributor

greptile-apps bot commented Feb 6, 2026

Greptile Overview

Greptile Summary

This PR introduces a new Python-side GroupedTensor storage container that allocates a single contiguous buffer (plus associated scale/amax buffers for quantized formats) and exposes per-tensor views back as QuantizedTensorStorage/QuantizedTensor objects. GroupedLinear is updated to optionally re-register its weight parameters as views into a shared GroupedTensor so grouped GEMMs can benefit from contiguous layout; storage classes add copy_from_storage() to enable copying already-quantized weights into the grouped backing buffers. New unit tests validate pointer offsets and view semantics for multiple quantization recipes, and sanity tests add contiguity checks for grouped-linear weights.

Two merge-blocking regressions are present in this changeset:

  • The Recipe predicate helpers (delayed/mxfp8/nvfp4/...) were changed to @classmethod implementations that call issubclass(), but the codebase calls these predicates on recipe instances (e.g., in pytorch/quantization.py and modules). This will raise TypeError at runtime.
  • _GroupedLinear.forward’s fp8_calibration block contains redundant nested loops (with index shadowing), causing calibration to run num_gemms times unnecessarily when enabled.

Confidence Score: 2/5

  • This PR is not safe to merge as-is due to a runtime-breaking recipe predicate change and a definite logic bug in fp8 calibration.
  • Score reduced because Recipe.*() predicate methods were changed to classmethods using issubclass, but the codebase calls them on recipe instances, which will raise a TypeError at runtime. Additionally, GroupedLinear’s fp8_calibration block contains nested redundant loops that over-calibrate and add overhead. The rest of the changes (GroupedTensor implementation, copy_from_storage additions, and tests) appear coherent but depend on fixing these issues.
  • transformer_engine/common/recipe/__init__.py, transformer_engine/pytorch/module/grouped_linear.py

Important Files Changed

  • transformer_engine/pytorch/tensor/storage/grouped_tensor.py: Adds the new Python GroupedTensor contiguous storage and splitting/quantization logic across multiple quantization recipes; core logic looks consistent but relies on recipe predicate behavior and has type annotation/API mismatches.
  • transformer_engine/pytorch/module/grouped_linear.py: Registers GroupedTensor-backed weights and adds make_grouped_weights; contains a definite nested-loop bug in fp8_calibration that repeats calibration work num_gemms times.
  • transformer_engine/common/recipe/__init__.py: Changes Recipe predicate helpers to classmethods using issubclass; breaks existing instance-based call sites (e.g., recipe.delayed()) with TypeError unless all callers are updated.
  • transformer_engine/pytorch/quantized_tensor.py: Adds the QuantizedTensorStorage.copy_from_storage abstract method used by GroupedLinear when grouping pre-quantized weights; the change is straightforward but requires all concrete storages to implement it (done in this PR).
  • transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py: Implements copy_from_storage for Float8TensorStorage to enable raw buffer copies into grouped storage; logic is simple and type-checked.
  • transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py: Implements copy_from_storage for MXFP8TensorStorage (including fp8_dtype and scale-layout checks) so grouped weights can copy quantized buffers.
  • transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py: Implements copy_from_storage for NVFP4TensorStorage (including fp4_dtype and scale-layout checks) to support grouped storage copies.
  • transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py: Implements copy_from_storage for Float8BlockwiseQTensorStorage to support grouped weight storage copies across block-scaled FP8.
  • transformer_engine/pytorch/tensor/float8_tensor.py: Minor quantizer/tensor adjustments to support grouped storage integration; no obvious functional regressions in the changed areas.
  • transformer_engine/pytorch/tensor/mxfp8_tensor.py: Updates MXFP8 quantizer/tensor utilities used by GroupedTensor scale shape computations; changes appear consistent with existing usage.
  • transformer_engine/pytorch/tensor/nvfp4_tensor.py: Small updates to NVFP4 quantizer/tensor utilities (e.g., scale shape helpers) used by GroupedTensor; no clear regressions found.
  • transformer_engine/pytorch/tensor/storage/__init__.py: Exports GroupedTensor from the storage package; simple namespace change.
  • tests/pytorch/test_grouped_tensor.py: Adds comprehensive unit tests for GroupedTensor construction/splitting/offsets across quantization recipes, including NVFP4 packed offsets.
  • tests/pytorch/test_sanity.py: Extends sanity tests to assert that grouped-linear weights are contiguous for supported recipes; relies on recipe predicate helpers and skips NVFP4.

Sequence Diagram

sequenceDiagram
participant User as User Code
participant QMI as quantized_model_init/autocast
participant GL as GroupedLinear
participant GT as GroupedTensor
participant Q as Quantizer

User->>QMI: enter quantized_model_init(enabled, recipe)
QMI->>GL: instantiate GroupedLinear
GL->>GL: reset_parameters()
GL->>GL: make_grouped_weights()
GL->>Q: _get_weight_quantizers()
GL->>Q: _get_compatible_recipe()
alt recipe supports grouping
  GL->>GT: make_grouped_tensor_with_shapes(...)
  GT->>GT: allocate contiguous buffers
  GT->>GT: split_into_quantized_tensors()
  GL->>GT: copy_from_storage()/copy_()
  GL->>GL: register_parameter(weight{i})
else delayed/current scaling
  GL->>GL: skip grouping
end
User->>GL: forward(inp, m_splits)
GL->>Q: set_usage(rowwise/columnwise)
GL->>GL: tex.split_quantize(inp_view, m_splits)
GL->>GL: get_weight_workspace(...)
GL->>GL: general_grouped_gemm(...)

Contributor

@greptile-apps greptile-apps bot left a comment

5 files reviewed, 1 comment


@ksivaman ksivaman added the MoE label Feb 6, 2026
from .nvfp4_tensor_storage import NVFP4TensorStorage


class GroupedTensor:
Collaborator

I don't think it's a good idea to put everything within a single class. We should have an abstract base class (GroupedTensorBase) and concrete classes like GroupedTensor (or UnquantizedGroupTensor?), MXFP8GroupedTensor, NVFP4GroupedTensor. The giant-pile-of-attrs design results in ugly implementations (like the if-else blocks in make_grouped_tensor) and it generalizes poorly (columnwise_data is treated very differently between FP8 and MXFP8, enough that giving them the same name is questionable). We do use this design in the C++ grouped tensor class, but that should be viewed as a short-term expedient and not a long-term design (#2388 (comment)).
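A rough sketch of the polymorphic design being suggested here (class and method names are hypothetical, not code in this PR):

from abc import ABC, abstractmethod
from typing import List

class GroupedTensorBase(ABC):
    """Format-agnostic grouped-storage interface; shared logic would live here."""

    @abstractmethod
    def allocate_buffers(self) -> None:
        """Allocate the contiguous data/scale buffers for this format."""

    @abstractmethod
    def split_into_quantized_tensors(self) -> List:
        """Expose per-tensor views appropriate to this format."""

class MXFP8GroupedTensor(GroupedTensorBase):
    def allocate_buffers(self) -> None:
        ...  # MXFP8-specific rowwise/columnwise data and scale buffers

    def split_into_quantized_tensors(self) -> List:
        ...  # MXFP8-specific views

class NVFP4GroupedTensor(GroupedTensorBase):
    def allocate_buffers(self) -> None:
        ...  # NVFP4-specific packed data and scale buffers

    def split_into_quantized_tensors(self) -> List:
        ...  # NVFP4-specific views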

Member

This ultimately depends on what we want to optimize for. If we believe that the majority of what we are going to write here is about "grouped" functionality that does not really care about the underlying format (or functionality where we could delegate that decision to C++, which has full knowledge of the quantizer type and could implement things without huge if/else blocks), then it makes sense to have a single class here. If we believe that the majority of the functionality will depend on the quantization format, then I agree that we should split this into multiple classes.
@ksivaman Can you comment on that?

Collaborator

I think GroupedTensor in Python should be a faithful copy of the C++ grouped tensor, so I do think it's okay to have a single class.

Member Author

This GroupedTensor in Python is used mostly as a storage class and is basically a copy of the C++ GroupedTensor with some additional functionality for weight initialization and convenience. I think it's best to keep it simple and avoid over-engineering at this stage. In the next steps, when we implement a QuantizedGroupedTensor, say for FSDP2 + quantized parameter support, we could revisit whether a small refactor would be helpful.

Collaborator

I agree matching the C++ class is reasonable for now so we can meet deadlines. In the long term, we should refactor both the C++ and Python class to have proper polymorphism. This will be a painful process.

@ksivaman ksivaman marked this pull request as draft February 6, 2026 21:28
columnwise_scale_inv = torch.empty(
    total_columnwise_scale_elements, dtype=torch.uint8, device=device
)
elif quantizer._get_compatible_recipe().delayed():
Collaborator

Since we have gone for a single quantizer, we should remove the delayed scaling and per-tensor current scaling recipes for now, since their quantizers are not stateless.

Member Author

This is what we discussed offline, but now that I think about it, this is used for FP8 parameter creation, so we cannot simply un-support recipes here. The correct approach is probably to use multiple quantizers, or at least to have a way for the user to supply multiple quantizers.

Collaborator

Having a single quantizer without disabling unsupported recipes will make it numerically incorrect for delayed scaling and per-tensor current scaling, right? Since they do need multiple quantizer instances which hold multiple stateful tensors.

Collaborator

Two approaches (also see #2654 (comment)):

  • Make contiguous weights an opt-in feature. The default behavior should have discrete weights for backward compatibility.
  • Raise a warning if running with FP8 DS so that users are aware there may be correctness issues.

columnwise_scale_inv = torch.empty(
    total_columnwise_scale_elements, dtype=torch.float32, device=device
)
elif quantizer._get_compatible_recipe().float8_current_scaling():
Collaborator

float8_current_scaling can work with GroupedTensor once we refactor its implementation to move the amax tensor out of its quantizer. Then it will be safe to put a single quantizer into the grouped tensor.

Member Author

Same comment as above^

result.append(tensor)

# Delayed scaling or current scaling (both use Float8TensorStorage)
elif recipe.delayed() or recipe.float8_current_scaling():
Collaborator

Let's raise an error for this case?

Member Author

Same comment as above^

dtype: Optional[torch.dtype] = None,
) -> GroupedTensor:
"""
Create a GroupedTensor for storing multiple weight tensors of the same shape.
Collaborator

Minor comment: the intent of this API is to create a grouped tensor with variable first_dims/last_dims, so we can state that in the comment, since this is not going to be used to create weights.

Also, the API could be named make_grouped_tensor_graph_safe, so people know it is safe to use within the forward/backward of a module that needs to be CUDA-graphable.

Member Author

This API is actually used to create the weights and is not graph safe (for now), which is fine as it's used one time during creation.

Collaborator

@vthumbe1503 vthumbe1503 Feb 9, 2026

I thought make_grouped_tensor_with_shapes is used to create weights, since the weights' shapes are going to be constant. What's the intent of make_grouped_tensor_with_shapes then?

And what's the API we are going to use to create inputs? Don't we need that one to be graph safe?

        torch.zeros(1, device=first_dims.device, dtype=first_dims.dtype),
        torch.cumsum(first_dims * logical_last_dim, dim=0),
    ]
)
Collaborator

I see the above comment about having a single kernel and am not sure what your plan is to implement that.
But with torch ops you can avoid one memory op using:

tensor_offsets = torch.empty(num_tensors + 1, device=first_dims.device, dtype=first_dims.dtype)
torch.cumsum(first_dims * logical_last_dim, dim=0, out=tensor_offsets[1:])

Member Author

torch.empty would do garbage initialization, whereas we need tensor_offsets[0] to be explicitly 0, so either way we'd have to launch multiple kernels if using PyTorch ops. That's why the plan is to later add a small CUDA kernel so that we can call it from the C++ extensions and from JAX as well.

Collaborator

Oh, my mistake: using torch.zeros instead of torch.empty should do the trick. Sure, a CUDA kernel later sounds good.
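For reference, a small sketch of the zeros-plus-cumsum version agreed on above (the function name is ours, not the PR's):

import torch

def make_tensor_offsets(first_dims: torch.Tensor, logical_last_dim: int) -> torch.Tensor:
    """Per-tensor element offsets with a leading 0, without the extra torch.cat."""
    num_tensors = first_dims.numel()
    # zeros (not empty) guarantees tensor_offsets[0] == 0 without another kernel
    tensor_offsets = torch.zeros(num_tensors + 1, device=first_dims.device, dtype=first_dims.dtype)
    # write the cumulative sums in place after the leading zero
    torch.cumsum(first_dims * logical_last_dim, dim=0, out=tensor_offsets[1:])
    return tensor_offsets

# Example: first dims [2, 4, 3] with a shared last dim of 8 -> offsets [0, 16, 48, 72]
offsets = make_tensor_offsets(torch.tensor([2, 4, 3]), 8)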

ksivaman and others added 3 commits February 9, 2026 23:20
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@ksivaman
Member Author

/te-ci L0

@ksivaman ksivaman marked this pull request as ready for review February 10, 2026 01:08
Contributor

@greptile-apps greptile-apps bot left a comment

9 files reviewed, 2 comments


Comment on lines +909 to +935
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    noop_flag: Optional[torch.Tensor] = None,
) -> Tuple[QuantizedTensorStorage, ...]:
    """
    Quantize given tensors into quantized tensors with underlying
    storage allocated in a GroupedTensor.
    """

    grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
        num_tensors=len(tensors),
        shape=[t.shape for t in tensors],
        quantizer=quantizer,
        device=device,
        dtype=dtype,
    )

    grouped_tensor.quantize(tensors, noop_flag=noop_flag)

    return grouped_tensor

def quantize(
    self,
    tensors: List[torch.Tensor],
    noop_flag: Optional[torch.Tensor] = None,
) -> Tuple[QuantizedTensorStorage, ...]:
    """
Contributor

Broken create_and_quantize API
create_and_quantize is annotated/declared as returning a Tuple[QuantizedTensorStorage, ...] and taking tensors: int, but the implementation uses len(tensors) / iterates tensors and returns the GroupedTensor instance (return grouped_tensor). Any caller relying on the annotated contract (tuple of quantized tensors) will break, and the tensors: int annotation is incompatible with the actual usage.

This should either return the grouped tensor’s split/quantized tensors (matching the annotation), or update the signature/return type to reflect that it returns a GroupedTensor and expects an iterable of tensors.
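For reference, here is the quoted method with the annotations brought in line with what it actually does, i.e. returning the GroupedTensor and accepting a list of tensors (one possible fix, not the PR's; module-level imports from the quoted file are assumed):

@staticmethod
def create_and_quantize(
    quantizer,
    tensors: List[torch.Tensor],
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    noop_flag: Optional[torch.Tensor] = None,
) -> "GroupedTensor":
    """
    Quantize the given tensors into storage allocated in a single GroupedTensor
    and return that GroupedTensor.
    """
    grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
        num_tensors=len(tensors),
        shape=[t.shape for t in tensors],
        quantizer=quantizer,
        device=device,
        dtype=dtype,
    )
    grouped_tensor.quantize(tensors, noop_flag=noop_flag)
    return grouped_tensor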

Comment on lines +469 to +479
        if i < num_tensors - 1:
            columnwise_scale_inv_offsets.append(total_columnwise_scale_elements)
    columnwise_scale_inv = torch.empty(
        total_columnwise_scale_elements, dtype=torch.uint8, device=device
    )
elif quantizer._get_compatible_recipe().delayed():
    if rowwise_usage:
        # Allocate rowwise data buffer (1D flattened, uint8)
        data = torch.empty(total_elements, dtype=torch.uint8, device=device)
        # Scale inverse - one per tensor
        scale_inv = torch.empty(num_tensors, dtype=torch.float32, device=device)
Contributor

MXFP8 columnwise scale size
In the MXFP8 columnwise_scale_inv_offsets allocation loop, scale_inv_shape = quantizer.get_scale_shape(s, False) is used even though this is the columnwise scale buffer. Later, split_into_quantized_tensors views columnwise_scale_inv with get_scale_shape(tensor_shape, True) (grouped_tensor.py:744-747). This mismatch will allocate the wrong number of elements and lead to incorrect views/out-of-bounds when columnwise scaling is enabled.

Use quantizer.get_scale_shape(s, True) when computing columnwise_scale_inv_offsets/total_columnwise_scale_elements.

Comment on lines +266 to +291
def __repr__(self) -> str:
    """String representation of the GroupedTensor."""
    return (
        f"GroupedTensor(num_tensors={self.num_tensors}, "
        f"shape={self.shape}, "
        f"logical_shape={self.logical_shape}, "
        f"dtype={self.get_dtype()})"
    )

def __str__(self) -> str:
    """User-friendly string representation."""
    shape_info = []
    if self.all_same_shape():
        shape_info.append("uniform shape")
    else:
        if not self.all_same_first_dim():
            shape_info.append("varying first dim")
        if not self.all_same_last_dim():
            shape_info.append("varying last dim")

    return (
        f"GroupedTensor with {self.num_tensors} tensors "
        f"({', '.join(shape_info) if shape_info else 'uniform'}), "
        f"logical_shape={self.logical_shape}, "
        f"dtype={self.get_dtype()}"
    )
Member

Why are those different?

)

@staticmethod
def make_grouped_tensor(
Member

Is the fact that this is on the Python side and not in C++ intentional, or a TODO? What is the actual usage of this call? Is it just a helper for weight creation and meant to be changed later?

grouped_tensor.quantized_tensors = grouped_tensor.split_into_quantized_tensors()
return grouped_tensor

def split_into_quantized_tensors(
Member

Is this mostly for debugging? In general it would be good to revisit the docs for the functions and indicate which ones we expect to be used in the typical case (and e.g. are graph safe) and which ones are for debug/not performant/not graph safe.

Member Author

Yes about the docs. But this is not only a debugging function; it is also used for creating the different parameters out of the GroupedTensor to be used as weights.

@staticmethod
def make_grouped_tensor_with_shapes(
    num_tensors: int,
    shape: List[Tuple[int, int]],
Collaborator

let's call it shapes

offsets = tensor_offsets.tolist()
first_dims_list = first_dims.tolist()
for i in range(num_tensors):
    shape.append((first_dims_list[i], logical_last_dim))
Collaborator

We should make this part consistent with the C++ grouped tensor creation, i.e., make them both graph safe.

weight_quantizers = self._get_weight_quantizers()

# Create the weight storage.
grouped_weights = GroupedTensor.make_grouped_tensor_with_shapes(
Collaborator

@zhongbozhu zhongbozhu Feb 10, 2026

How can GroupedLinear reference this grouped_weights as a whole and trigger the grouped GEMM?
Maybe we need self.grouped_weights = grouped_weights here and then pass a reference to this variable to grouped_linear_forward.

Collaborator

Those changes will be part of this PR @zhongbozhu
#2669

grouped_weights.quantized_tensors[i].copy_(weights[i])

# Re-register the grouped weights as parameters.
for i in range(self.num_gemms):
Collaborator

One TODO item for future reference: we might need a toggle to support registering the grouped parameters as a whole.


self.set_tensor_parallel_attributes(defer_init=defer_init)

def reset_parameters(self, defer_init=False):
Collaborator

Can we adjust the order of the functions here? E.g., put helper functions on top and the caller at the bottom.

Collaborator

I think keeping self.set_tensor_parallel_attributes(defer_init=defer_init) at the end of reset_parameters is ideal, since otherwise these attributes would be set again during meta-device init after reset_parameters is called.

Member Author

Yes, this needs to be the last step so that those TP-specific attributes are retained.

Contributor

@greptile-apps greptile-apps bot left a comment

3 files reviewed, 2 comments


@ksivaman
Member Author

/te-ci

ptrendx previously approved these changes Feb 10, 2026
Member

@ptrendx ptrendx left a comment

Conditional approval - we will be iterating on it in the subsequent PRs.

Comment on lines +118 to +119
if dst is not None and src_tensor is not None:
    dst.copy_(src_tensor)
Collaborator

We should handle the edge case where the tensor buffers don't match, or else we won't have confidence that the tensor is safe to read after this function:

Suggested change

-    if dst is not None and src_tensor is not None:
-        dst.copy_(src_tensor)
+    if dst is not None:
+        if src_tensor is None:
+            raise RuntimeError(
+                f"Attempting to copy from {src.__class__.__name__} "
+                f"to {self.__class__.__name__}, but data buffers are not compatible."
+            )
+        dst.copy_(src_tensor)

Alternatively, we could manually transpose the missing data or dequantize-requantize.

We should also do something similar for the other tensor classes.

grouped_weights = GroupedTensor.make_grouped_tensor_with_shapes(
    num_tensors=self.num_gemms,
    shape=[(self.out_features, self.in_features)] * self.num_gemms,
    quantizer=weight_quantizers[0],
Collaborator

This will subtly break compatibility with FP8 DS. If we're not willing to fully deprecate FP8 DS, maybe we should raise a warning that we are prototyping an experimental design and that results may not be correct.

That sucks, so maybe we should add an option for contiguous weights and set it to false by default. That way we can maintain backward compatibility.

Collaborator

This will also break FP8 CS, which is still actively used. We should just skip creating grouped weights for these two recipes.

Member Author

For now I have disabled these two recipes with grouped weights. In the next PR, I will make this feature accessible only via an env var.
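As a rough sketch of what env-var gating could look like (the variable and helper names are hypothetical, not what the follow-up PR will use):

import os

# Hypothetical flag; the follow-up PR has not specified the actual env var name.
_GROUPED_WEIGHTS_ENABLED = os.getenv("NVTE_GROUPED_LINEAR_CONTIGUOUS_WEIGHTS", "0") == "1"

def maybe_make_grouped_weights(module) -> None:
    """Opt-in grouping: the default keeps discrete weights for backward compatibility."""
    if not _GROUPED_WEIGHTS_ENABLED:
        return
    module.make_grouped_weights()  # method added by this PR on GroupedLinear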

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Contributor

@greptile-apps greptile-apps bot left a comment

14 files reviewed, 1 comment


Comment on lines +91 to +94
 @classmethod
 def nvfp4(cls):
     """Whether the given recipe is NVFP4 1D block scaling."""
-    return isinstance(self, NVFP4BlockScaling)
+    return issubclass(cls, NVFP4BlockScaling)
Contributor

Broken recipe predicate calls

These recipe predicates were changed to @classmethod + issubclass(cls, ...), which will raise TypeError if any call site passes a recipe instance instead of a class (e.g., Recipe()), since issubclass() requires a class. This PR already has call sites like quantizer._get_compatible_recipe().mxfp8() that only work if _get_compatible_recipe() returns a class; if it ever returns an instance, this becomes a runtime error. Consider either keeping these as instance methods (isinstance) or making _get_compatible_recipe()’s contract unambiguously return a type and enforcing it at the boundary.

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Contributor

@greptile-apps greptile-apps bot left a comment

14 files reviewed, no comments


Contributor

@greptile-apps greptile-apps bot left a comment

14 files reviewed, 2 comments


Comment on lines +91 to +119
 @classmethod
 def nvfp4(cls):
     """Whether the given recipe is NVFP4 1D block scaling."""
-    return isinstance(self, NVFP4BlockScaling)
+    return issubclass(cls, NVFP4BlockScaling)

-def mxfp8(self):
+@classmethod
+def mxfp8(cls):
     """Whether the given recipe is MXFP8 block scaling."""
-    return isinstance(self, MXFP8BlockScaling)
+    return issubclass(cls, MXFP8BlockScaling)

-def delayed(self):
+@classmethod
+def delayed(cls):
     """Whether the given recipe is delayed scaling."""
-    return isinstance(self, DelayedScaling)
+    return issubclass(cls, DelayedScaling)

-def float8_current_scaling(self):
+@classmethod
+def float8_current_scaling(cls):
     """Whether the given recipe is (per-tensor) current scaling."""
-    return isinstance(self, Float8CurrentScaling)
+    return issubclass(cls, Float8CurrentScaling)

-def float8_per_tensor_scaling(self):
+@classmethod
+def float8_per_tensor_scaling(cls):
     """Whether the given recipe is per-tensor scaling."""
-    return isinstance(self, (DelayedScaling, Float8CurrentScaling))
+    return issubclass(cls, (DelayedScaling, Float8CurrentScaling))

-def float8_block_scaling(self):
+@classmethod
+def float8_block_scaling(cls):
     """Whether the given recipe is float8 blockwise scaling."""
-    return isinstance(self, Float8BlockScaling)
+    return issubclass(cls, Float8BlockScaling)
Contributor

Recipe predicates now classmethods

Recipe.delayed()/mxfp8()/nvfp4()/... were changed to @classmethod implementations that call issubclass(cls, ...). Many call sites pass instances (e.g., recipe = FP8GlobalStateManager.get_fp8_recipe() in pytorch/quantization.py:1017 and throughout modules), so recipe.delayed() will run with cls as an instance and issubclass() will raise TypeError: issubclass() arg 1 must be a class. Unless all call sites were updated to use recipe.__class__.delayed() (or similar), this is a hard runtime break.

Contributor

greptile-apps bot commented Feb 11, 2026

Additional Comments (1)

transformer_engine/pytorch/module/grouped_linear.py
FP8 calibration loops nested

In _GroupedLinear.forward, the fp8_calibration block has for i in range(num_gemms) repeated three times with the same loop variable (lines ~219-226). As written, the outer loop causes the input_quantizers[i].calibrate(...) / weight_quantizers[i].calibrate(...) passes to run num_gemms times each, and the inner for i ... shadows the outer index. This will over-calibrate (and adds avoidable overhead) whenever fp8_calibration=True.
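A schematic reconstruction of the loop structure described above (stub functions, not TE's code), showing why the redundant outer loop multiplies the calibration work:

num_gemms = 4
calls = []

def calibrate_input(i):   # stand-in for input_quantizers[i].calibrate(...)
    calls.append(("input", i))

def calibrate_weight(i):  # stand-in for weight_quantizers[i].calibrate(...)
    calls.append(("weight", i))

# Buggy shape: an outer loop wrapping inner loops that reuse the same index i,
# so each calibration pass runs num_gemms times.
for i in range(num_gemms):
    for i in range(num_gemms):
        calibrate_input(i)
    for i in range(num_gemms):
        calibrate_weight(i)
assert len(calls) == 2 * num_gemms * num_gemms  # over-calibration

# Intended behavior: a single pass over the GEMMs.
calls.clear()
for i in range(num_gemms):
    calibrate_input(i)
    calibrate_weight(i)
assert len(calls) == 2 * num_gemms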

@ksivaman ksivaman merged commit ac81c85 into NVIDIA:main Feb 11, 2026
10 of 12 checks passed