-
Notifications
You must be signed in to change notification settings - Fork 635
[PyTorch] Python GroupedTensor
#2654
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Greptile OverviewGreptile SummaryThis PR introduces a new Python-side Two merge-blocking regressions are present in this changeset:
Confidence Score: 2/5
Important Files Changed
Sequence DiagramsequenceDiagram
participant User as User Code
participant QMI as quantized_model_init/autocast
participant GL as GroupedLinear
participant GT as GroupedTensor
participant Q as Quantizer
User->>QMI: enter quantized_model_init(enabled, recipe)
QMI->>GL: instantiate GroupedLinear
GL->>GL: reset_parameters()
GL->>GL: make_grouped_weights()
GL->>Q: _get_weight_quantizers()
GL->>Q: _get_compatible_recipe()
alt recipe supports grouping
GL->>GT: make_grouped_tensor_with_shapes(...)
GT->>GT: allocate contiguous buffers
GT->>GT: split_into_quantized_tensors()
GL->>GT: copy_from_storage()/copy_()
GL->>GL: register_parameter(weight{i})
else delayed/current scaling
GL->>GL: skip grouping
end
User->>GL: forward(inp, m_splits)
GL->>Q: set_usage(rowwise/columnwise)
GL->>GL: tex.split_quantize(inp_view, m_splits)
GL->>GL: get_weight_workspace(...)
GL->>GL: general_grouped_gemm(...)
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
5 files reviewed, 1 comment
| from .nvfp4_tensor_storage import NVFP4TensorStorage | ||
|
|
||
|
|
||
| class GroupedTensor: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think it's a good idea to put everything within a single class. We should have an abstract base class (GroupedTensorBase) and concrete classes like GroupedTensor (or UnquantizedGroupTensor?), MXFP8GroupedTensor, NVFP4GroupedTensor. The giant-pile-of-attrs design results in ugly implemenations (like the if-else blocks in make_grouped_tensor) and it generalizes poorly (columnwise_data is treated very differently between FP8 and MXFP8, enough that giving them the same name is questionable). We do use this design in the C++ grouped tensor class, but that should be viewed as a short-term expedient and not a long-term design (#2388 (comment)).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This ultimately depends on what we want to optimize for. If we believe that the majority of things we are going to write is going to be here is about "grouped" functionality that does not really care about the underlying format (or stuff where we could delegate that decision to C++ which has the full knowledge of the quantizer type and could implement things without huge if/else blocks) then it makes sense to have a single class here. If we believe that the majority of the functionality will be dependent on the quantization format then I agree that we should split this into multiple classes.
@ksivaman Can you comment on that?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think GroupedTensor in python should be a truthful copy of the C++ grouped tensor, so I do think it's okay to have a single class.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This GroupedTensor in python is used mostly as a storage class and is basically a copy of C GroupedTensor with some additional functionality for weight initialization and convenience. I think it's best to keep it simple and avoid over engineering at this stage. In the next steps when we implement a QuantizedGroupedTensor, say for FSDP2 + quantized parameter support, we could revisit if a small refactor would be helpful.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree matching the C++ class is reasonable for now so we can meet deadlines. In the long term, we should refactor both the C++ and Python class to have proper polymorphism. This will be a painful process.
| columnwise_scale_inv = torch.empty( | ||
| total_columnwise_scale_elements, dtype=torch.uint8, device=device | ||
| ) | ||
| elif quantizer._get_compatible_recipe().delayed(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since we have gone for single quantizer, we should remove delayed scaling recipe & per-tensor current scaling for now since their quantizers are not stateless.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is what we discussed offline, but actually now that I think about it, this is used for FP8 parameters creation, so we cannot simply un-support recipes here. The correct method is to probably use multiple quantizers, or at least have a way for the user to supply multiple quantizers.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
having a single quantizer without disabling unsupported recipes will make it numerically incorrect for delayed scaling & per-tensor current scaling right? since they do need multiple quantizer instances with hold multiple stateful tensors
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Two approaches (also see #2654 (comment)):
- Make contiguous weights an opt-in feature. The default behavior should have discrete weights for backward compatibility.
- Raise a warning if running with FP8 DS so that users are aware there may be correctness issues.
| columnwise_scale_inv = torch.empty( | ||
| total_columnwise_scale_elements, dtype=torch.float32, device=device | ||
| ) | ||
| elif quantizer._get_compatible_recipe().float8_current_scaling(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
float8_current_scaling can work with GroupedTensor once we refactored its implementation to remove the amax tensor out of its quantizer. Then it will be safe to put a single quantizer into the grouped tensor.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same comment as above^
| result.append(tensor) | ||
|
|
||
| # Delayed scaling or current scaling (both use Float8TensorStorage) | ||
| elif recipe.delayed() or recipe.float8_current_scaling(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's assert an error for this case?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same comment as above^
| dtype: Optional[torch.dtype] = None, | ||
| ) -> GroupedTensor: | ||
| """ | ||
| Create a GroupedTensor for storing multiple weight tensors of the same shape. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Minor comment: Intent of this API is to create grouped tensor with variable first_dims/last_dims, so we can write that in the comment, since this is not going to be used to create weights.
Also the API can be named to make_grouped_tensor_graph_safe? So, people know this API is safe to use within a forward/backward of a module which we need to be cuda graphable
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This API is actually used to create the weights and is not graph safe (for now), which is fine as it's used 1 time during creation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought make_grouped_tensor_with_shapes is used create weights. Since weight's shapes are going to be constant. Whats the intent of make_grouped_tensor_with_shapes then?
And whats the API we are going to be using to create inputs? Dont we need graph safe for that one?
| torch.zeros(1, device=first_dims.device, dtype=first_dims.dtype), | ||
| torch.cumsum(first_dims * logical_last_dim, dim=0), | ||
| ] | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see the above comment to have single kernel and am not sure what your plan is to implement that.
But with torch op you can avoid one memory op using
tensor_offsets = torch.empty(num_tensors + 1, device=first_dims.device, dtype=first_dims.dtype)
torch.cumsum(first_dims * logical_last_dim, dim=0, out=tensor_offsets[1:])
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
torch.empty would do garbage initialization whereas we need tensor_offsets[0] to be explicitly 0, so either way we'd have to do multiple kernels if using pytorch ops. That's why the plan is to later add a small cuda kernel so that we can call it from the C++ extensions and also for Jax as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh my mistake using torch.zeros instead of torch.empty should do the trick. Sure cuda kernel later sounds good.
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
|
/te-ci L0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
9 files reviewed, 2 comments
| device: Optional[torch.device] = None, | ||
| dtype: Optional[torch.dtype] = None, | ||
| noop_flag: Optional[torch.Tensor] = None, | ||
| ) -> Tuple[QuantizedTensorStorage, ...]: | ||
| """ | ||
| Quantize given tensors into quantized tensors with underlying | ||
| storage allocated in a GroupedTensor. | ||
| """ | ||
|
|
||
| grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( | ||
| num_tensors=len(tensors), | ||
| shape=[t.shape for t in tensors], | ||
| quantizer=quantizer, | ||
| device=device, | ||
| dtype=dtype, | ||
| ) | ||
|
|
||
| grouped_tensor.quantize(tensors, noop_flag=noop_flag) | ||
|
|
||
| return grouped_tensor | ||
|
|
||
| def quantize( | ||
| self, | ||
| tensors: List[torch.Tensor], | ||
| noop_flag: Optional[torch.Tensor] = None, | ||
| ) -> Tuple[QuantizedTensorStorage, ...]: | ||
| """ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Broken create_and_quantize API
create_and_quantize is annotated/declared as returning a Tuple[QuantizedTensorStorage, ...] and taking tensors: int, but the implementation uses len(tensors) / iterates tensors and returns the GroupedTensor instance (return grouped_tensor). Any caller relying on the annotated contract (tuple of quantized tensors) will break, and the tensors: int annotation is incompatible with the actual usage.
This should either return the grouped tensor’s split/quantized tensors (matching the annotation), or update the signature/return type to reflect that it returns a GroupedTensor and expects an iterable of tensors.
| if i < num_tensors - 1: | ||
| columnwise_scale_inv_offsets.append(total_columnwise_scale_elements) | ||
| columnwise_scale_inv = torch.empty( | ||
| total_columnwise_scale_elements, dtype=torch.uint8, device=device | ||
| ) | ||
| elif quantizer._get_compatible_recipe().delayed(): | ||
| if rowwise_usage: | ||
| # Allocate rowwise data buffer (1D flattened, uint8) | ||
| data = torch.empty(total_elements, dtype=torch.uint8, device=device) | ||
| # Scale inverse - one per tensor | ||
| scale_inv = torch.empty(num_tensors, dtype=torch.float32, device=device) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
MXFP8 columnwise scale size
In the MXFP8 columnwise_scale_inv_offsets allocation loop, scale_inv_shape = quantizer.get_scale_shape(s, False) is used even though this is the columnwise scale buffer. Later, split_into_quantized_tensors views columnwise_scale_inv with get_scale_shape(tensor_shape, True) (grouped_tensor.py:744-747). This mismatch will allocate the wrong number of elements and lead to incorrect views/out-of-bounds when columnwise scaling is enabled.
Use quantizer.get_scale_shape(s, True) when computing columnwise_scale_inv_offsets/total_columnwise_scale_elements.
| def __repr__(self) -> str: | ||
| """String representation of the GroupedTensor.""" | ||
| return ( | ||
| f"GroupedTensor(num_tensors={self.num_tensors}, " | ||
| f"shape={self.shape}, " | ||
| f"logical_shape={self.logical_shape}, " | ||
| f"dtype={self.get_dtype()})" | ||
| ) | ||
|
|
||
| def __str__(self) -> str: | ||
| """User-friendly string representation.""" | ||
| shape_info = [] | ||
| if self.all_same_shape(): | ||
| shape_info.append("uniform shape") | ||
| else: | ||
| if not self.all_same_first_dim(): | ||
| shape_info.append("varying first dim") | ||
| if not self.all_same_last_dim(): | ||
| shape_info.append("varying last dim") | ||
|
|
||
| return ( | ||
| f"GroupedTensor with {self.num_tensors} tensors " | ||
| f"({', '.join(shape_info) if shape_info else 'uniform'}), " | ||
| f"logical_shape={self.logical_shape}, " | ||
| f"dtype={self.get_dtype()}" | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why are those different?
| ) | ||
|
|
||
| @staticmethod | ||
| def make_grouped_tensor( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is the fact that this is on the Python side and not C++ is intentional or a TODO? What is the actual usage of this call? Is it just a helper for the weights creation and meant to be changed later?
| grouped_tensor.quantized_tensors = grouped_tensor.split_into_quantized_tensors() | ||
| return grouped_tensor | ||
|
|
||
| def split_into_quantized_tensors( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this mostly for debugging? In general it would be good to revisit the docs for the functions and indicate which ones we expect to be used in the typical case (and e.g. are graph safe) and which ones are for debug/not performant/not graph safe.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes about the docs. But this is not only a debugging function but also used for creating different parameters out of the Grouped Tensor to be used as weights.
| @staticmethod | ||
| def make_grouped_tensor_with_shapes( | ||
| num_tensors: int, | ||
| shape: List[Tuple[int, int]], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's call it shapes
| offsets = tensor_offsets.tolist() | ||
| first_dims_list = first_dims.tolist() | ||
| for i in range(num_tensors): | ||
| shape.append((first_dims_list[i], logical_last_dim)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should make this part consistent with the C++ grouped tensor creation. ie. making them both graph safe.
| weight_quantizers = self._get_weight_quantizers() | ||
|
|
||
| # Create the weight storage. | ||
| grouped_weights = GroupedTensor.make_grouped_tensor_with_shapes( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
how can grouped linear reference this grouped_weights as a whole and trigger the grouped gemm?
maybe we need self.grouped_weights = grouped_weights here and then pass the reference of this variable to the grouped_linear_forward.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Those changes will be part of this PR @zhongbozhu
#2669
| grouped_weights.quantized_tensors[i].copy_(weights[i]) | ||
|
|
||
| # Re-register the grouped weights as parameters. | ||
| for i in range(self.num_gemms): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
one TODO item for the future reference: we might need to support a toggle to support registering the grouped parameters as a whole
|
|
||
| self.set_tensor_parallel_attributes(defer_init=defer_init) | ||
|
|
||
| def reset_parameters(self, defer_init=False): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we adjust the order of the functions here? like put helper functions on top and caller in the bottom
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think keeping self.set_tensor_parallel_attributes(defer_init=defer_init) at the end of reset_parameters is ideal. Since otherwise these attributes would be set again during meta device init after reset_parameters is called
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, this needs to be the last step such that those TP specific attributes are retained
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
3 files reviewed, 2 comments
|
/te-ci |
ptrendx
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Conditional approval - we will be iterating on it in the subsequent PRs.
| if dst is not None and src_tensor is not None: | ||
| dst.copy_(src_tensor) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should handle the edge case that the tensor buffer don't match, or else we won't have confidence that the tensor is safe to read after this function:
| if dst is not None and src_tensor is not None: | |
| dst.copy_(src_tensor) | |
| if dst is not None: | |
| if src_tensor is None: | |
| raise RuntimeError( | |
| f"Attempting to copy from {src.__class__.__name__} " | |
| f"to {self.__class__.__name__}, but data buffers are not compatible." | |
| ) | |
| dst.copy_(src_tensor) |
Alternatively, we could manually transpose the missing data or dequantize-requantize.
We should also do something similar for the other tensor classes.
| grouped_weights = GroupedTensor.make_grouped_tensor_with_shapes( | ||
| num_tensors=self.num_gemms, | ||
| shape=[(self.out_features, self.in_features)] * self.num_gemms, | ||
| quantizer=weight_quantizers[0], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This will subtly break compatibility with FP8 DS. If we're not willing to fully deprecate FP8 DS, maybe we should raise a warning that we are prototyping an experimental design and that results may not be correct.
That sucks, so maybe we should add an option for contiguous weights and set it false by default. That way we can maintain backward compatibility.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This will also break FP8 CS too which is still being actively used. We should just skip creating grouped weight for these two recipes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For now I have disabled the 2 recipe with Grouped weight. But for the next PR, I will make this feature accessible only via an envvar.
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
14 files reviewed, 1 comment
| @classmethod | ||
| def nvfp4(cls): | ||
| """Whether the given recipe is NVFP4 1D block scaling.""" | ||
| return isinstance(self, NVFP4BlockScaling) | ||
| return issubclass(cls, NVFP4BlockScaling) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Broken recipe predicate calls
These recipe predicates were changed to @classmethod + issubclass(cls, ...), which will raise TypeError if any call site passes a recipe instance instead of a class (e.g., Recipe()), since issubclass() requires a class. This PR already has call sites like quantizer._get_compatible_recipe().mxfp8() that only work if _get_compatible_recipe() returns a class; if it ever returns an instance, this becomes a runtime error. Consider either keeping these as instance methods (isinstance) or making _get_compatible_recipe()’s contract unambiguously return a type and enforcing it at the boundary.
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
14 files reviewed, no comments
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
14 files reviewed, 2 comments
| @classmethod | ||
| def nvfp4(cls): | ||
| """Whether the given recipe is NVFP4 1D block scaling.""" | ||
| return isinstance(self, NVFP4BlockScaling) | ||
| return issubclass(cls, NVFP4BlockScaling) | ||
|
|
||
| def mxfp8(self): | ||
| @classmethod | ||
| def mxfp8(cls): | ||
| """Whether the given recipe is MXFP8 block scaling.""" | ||
| return isinstance(self, MXFP8BlockScaling) | ||
| return issubclass(cls, MXFP8BlockScaling) | ||
|
|
||
| def delayed(self): | ||
| @classmethod | ||
| def delayed(cls): | ||
| """Whether the given recipe is delayed scaling.""" | ||
| return isinstance(self, DelayedScaling) | ||
| return issubclass(cls, DelayedScaling) | ||
|
|
||
| def float8_current_scaling(self): | ||
| @classmethod | ||
| def float8_current_scaling(cls): | ||
| """Whether the given recipe is (per-tensor) current scaling.""" | ||
| return isinstance(self, Float8CurrentScaling) | ||
| return issubclass(cls, Float8CurrentScaling) | ||
|
|
||
| def float8_per_tensor_scaling(self): | ||
| @classmethod | ||
| def float8_per_tensor_scaling(cls): | ||
| """Whether the given recipe is per-tensor scaling.""" | ||
| return isinstance(self, (DelayedScaling, Float8CurrentScaling)) | ||
| return issubclass(cls, (DelayedScaling, Float8CurrentScaling)) | ||
|
|
||
| def float8_block_scaling(self): | ||
| @classmethod | ||
| def float8_block_scaling(cls): | ||
| """Whether the given recipe is float8 blockwise scaling.""" | ||
| return isinstance(self, Float8BlockScaling) | ||
| return issubclass(cls, Float8BlockScaling) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Recipe predicates now classmethods
Recipe.delayed()/mxfp8()/nvfp4()/... were changed to @classmethod implementations that call issubclass(cls, ...). Many call sites pass instances (e.g., recipe = FP8GlobalStateManager.get_fp8_recipe() in pytorch/quantization.py:1017 and throughout modules), so recipe.delayed() will run with cls as an instance and issubclass() will raise TypeError: issubclass() arg 1 must be a class. Unless all call sites were updated to use recipe.__class__.delayed() (or similar), this is a hard runtime break.
Additional Comments (1)
In |
Description
Extracts the python pieces of
GroupedTensorinfrastructure from #2600. Since this is mainly focused on creation of weights as a singleGroupedTensorand exposing them as multipleQuantizedTensors for PyTorch, this portion does not need to be graph capturable.Type of change
Changes
GroupedTensorclass.GroupedTensorintoGroupedLinearsuch that the parameters are contiguous.Checklist: