From 00f4591c97580ccbc5d5ac363350dd7fd09d6b15 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Fri, 13 Feb 2026 09:32:29 +0100 Subject: [PATCH 1/2] first working commit --- pina/_src/condition/condition_base.py | 16 +- pina/_src/core/trainer.py | 33 +- pina/_src/data/aggregator.py | 58 +++ pina/_src/data/creator.py | 178 ++++++++ pina/_src/data/data_module.py | 620 ++++++-------------------- pina/_src/data/dummy_dataloader.py | 62 +++ pina/_src/problem/abstract_problem.py | 74 +-- 7 files changed, 517 insertions(+), 524 deletions(-) create mode 100644 pina/_src/data/aggregator.py create mode 100644 pina/_src/data/creator.py create mode 100644 pina/_src/data/dummy_dataloader.py diff --git a/pina/_src/condition/condition_base.py b/pina/_src/condition/condition_base.py index b8290d717..44a8af2b7 100644 --- a/pina/_src/condition/condition_base.py +++ b/pina/_src/condition/condition_base.py @@ -9,6 +9,7 @@ from pina._src.condition.condition_interface import ConditionInterface from pina._src.core.graph import LabelBatch from pina._src.core.label_tensor import LabelTensor +from pina._src.data.dummy_dataloader import DummyDataloader class ConditionBase(ConditionInterface): @@ -85,7 +86,8 @@ def automatic_batching_collate_fn(cls, batch): if not batch: return {} instance_class = batch[0].__class__ - return instance_class.create_batch(batch) + batch = instance_class.create_batch(batch) + return batch @staticmethod def collate_fn(batch, condition): @@ -103,7 +105,11 @@ def collate_fn(batch, condition): return data def create_dataloader( - self, dataset, batch_size, shuffle, automatic_batching + self, + dataset, + batch_size, + automatic_batching, + **kwargs, ): """ Create a DataLoader for the condition. @@ -114,14 +120,14 @@ def create_dataloader( :rtype: torch.utils.data.DataLoader """ if batch_size == len(dataset): - pass # will be updated in the near future + return DummyDataloader(dataset) return DataLoader( dataset=dataset, - batch_size=batch_size, - shuffle=shuffle, collate_fn=( partial(self.collate_fn, condition=self) if not automatic_batching else self.automatic_batching_collate_fn ), + batch_size=batch_size, + **kwargs, ) diff --git a/pina/_src/core/trainer.py b/pina/_src/core/trainer.py index 7500be537..377b42fac 100644 --- a/pina/_src/core/trainer.py +++ b/pina/_src/core/trainer.py @@ -36,7 +36,7 @@ def __init__( test_size=0.0, val_size=0.0, compile=None, - repeat=None, + batching_mode="common_batch_size", automatic_batching=None, num_workers=None, pin_memory=None, @@ -61,9 +61,9 @@ def __init__( :param bool compile: If ``True``, the model is compiled before training. Default is ``False``. For Windows users, it is always disabled. Not supported for python version greater or equal than 3.14. - :param bool repeat: Whether to repeat the dataset data in each - condition during training. For further details, see the - :class:`~pina.data.data_module.PinaDataModule` class. Default is + :param str batching_mode: The batching mode to use. Options are + ``"common_batch_size"``, ``"proportional"``, and + ``"separate_conditions"``. Default is ``"common_batch_size"``. ``False``. 
:param bool automatic_batching: If ``True``, automatic PyTorch batching is performed, otherwise the items are retrieved from the dataset @@ -87,7 +87,7 @@ def __init__( train_size=train_size, test_size=test_size, val_size=val_size, - repeat=repeat, + batching_mode=batching_mode, automatic_batching=automatic_batching, compile=compile, ) @@ -127,8 +127,6 @@ def __init__( UserWarning, ) - repeat = repeat if repeat is not None else False - automatic_batching = ( automatic_batching if automatic_batching is not None else False ) @@ -144,7 +142,7 @@ def __init__( test_size=test_size, val_size=val_size, batch_size=batch_size, - repeat=repeat, + batching_mode=batching_mode, automatic_batching=automatic_batching, pin_memory=pin_memory, num_workers=num_workers, @@ -182,7 +180,7 @@ def _create_datamodule( test_size, val_size, batch_size, - repeat, + batching_mode, automatic_batching, pin_memory, num_workers, @@ -201,8 +199,9 @@ def _create_datamodule( :param float val_size: The percentage of elements to include in the validation dataset. :param int batch_size: The number of samples per batch to load. - :param bool repeat: Whether to repeat the dataset data in each - condition during training. + :param str batching_mode: The batching mode to use. Options are + ``"common_batch_size"``, ``"proportional"``, and + ``"separate_conditions"``. :param bool automatic_batching: Whether to perform automatic batching with PyTorch. :param bool pin_memory: Whether to use pinned memory for faster data @@ -232,7 +231,7 @@ def _create_datamodule( test_size=test_size, val_size=val_size, batch_size=batch_size, - repeat=repeat, + batching_mode=batching_mode, automatic_batching=automatic_batching, num_workers=num_workers, pin_memory=pin_memory, @@ -284,7 +283,7 @@ def _check_input_consistency( train_size, test_size, val_size, - repeat, + batching_mode, automatic_batching, compile, ): @@ -298,8 +297,9 @@ def _check_input_consistency( test dataset. :param float val_size: The percentage of elements to include in the validation dataset. - :param bool repeat: Whether to repeat the dataset data in each - condition during training. + :param str batching_mode: The batching mode to use. Options are + ``"common_batch_size"``, ``"proportional"``, and + ``"separate_conditions"``. :param bool automatic_batching: Whether to perform automatic batching with PyTorch. :param bool compile: If ``True``, the model is compiled before training. @@ -309,8 +309,7 @@ def _check_input_consistency( check_consistency(train_size, float) check_consistency(test_size, float) check_consistency(val_size, float) - if repeat is not None: - check_consistency(repeat, bool) + check_consistency(batching_mode, str) if automatic_batching is not None: check_consistency(automatic_batching, bool) if compile is not None: diff --git a/pina/_src/data/aggregator.py b/pina/_src/data/aggregator.py new file mode 100644 index 000000000..c788132c2 --- /dev/null +++ b/pina/_src/data/aggregator.py @@ -0,0 +1,58 @@ +""" +Aggregator for multiple dataloaders. +""" + + +class _Aggregator: + """ + The class :class:`_Aggregator` is responsible for aggregating multiple + dataloaders into a single iterable object. It supports different batching + modes to accommodate various training requirements. + """ + + def __init__(self, dataloaders, batching_mode): + """ + Initialization of the :class:`_Aggregator` class. + + :param dataloaders: A dictionary mapping condition names to their + respective dataloaders. 
+ :type dataloaders: dict[str, DataLoader] + :param batching_mode: The batching mode to use. Options are + ``"common_batch_size"``, ``"proportional"``, and + ``"separate_conditions"``. + :type batching_mode: str + """ + self.dataloaders = dataloaders + self.batching_mode = batching_mode + + def __len__(self): + """ + Return the length of the aggregated dataloader. + + :return: The length of the aggregated dataloader. + :rtype: int + """ + return max(len(dl) for dl in self.dataloaders.values()) + + def __iter__(self): + """ + Return an iterator over the aggregated dataloader. + + :return: An iterator over the aggregated dataloader. + :rtype: iterator + """ + if self.batching_mode == "separate_conditions": + for name, dl in self.dataloaders.items(): + for batch in dl: + yield {name: batch} + return + iterators = {name: iter(dl) for name, dl in self.dataloaders.items()} + for _ in range(len(self)): + batch = {} + for name, it in iterators.items(): + try: + batch[name] = next(it) + except StopIteration: + iterators[name] = iter(self.dataloaders[name]) + batch[name] = next(iterators[name]) + yield batch diff --git a/pina/_src/data/creator.py b/pina/_src/data/creator.py new file mode 100644 index 000000000..b0e6d37c1 --- /dev/null +++ b/pina/_src/data/creator.py @@ -0,0 +1,178 @@ +""" +Module defining the Creator class, responsible for creating dataloaders +for multiple conditions with various batching strategies. +""" + +import torch +from torch.utils.data import RandomSampler, SequentialSampler +from torch.utils.data.distributed import DistributedSampler + + +class _Creator: + """ + The class :class:`_Creator` is responsible for creating dataloaders for + multiple conditions based on specified batching strategies. It supports + different batching modes to accommodate various training requirements. + """ + + def __init__( + self, + batching_mode, + batch_size, + shuffle, + automatic_batching, + num_workers, + pin_memory, + conditions, + ): + """ + Initialization of the :class:`_Creator` class. + + :param batching_mode: The batching mode to use. Options are + ``"common_batch_size"``, ``"proportional"``, and + ``"separate_conditions"``. + :type batching_mode: str + :param batch_size: The batch size to use for dataloaders. If + ``batching_mode`` is ``"proportional"``, this represents the total + batch size across all conditions. + :type batch_size: int | None + :param shuffle: Whether to shuffle the data in the dataloaders. + :type shuffle: bool + :param automatic_batching: Whether to use automatic batching in the + dataloaders. + :type automatic_batching: bool + :param num_workers: The number of worker processes to use for data + loading. + :type num_workers: int + :param pin_memory: Whether to pin memory in the dataloaders. + :type pin_memory: bool + :param conditions: A dictionary mapping condition names to their + respective condition objects. + :type conditions: dict[str, Condition] + """ + self.batching_mode = batching_mode + self.batch_size = batch_size + self.shuffle = shuffle + self.automatic_batching = automatic_batching + self.num_workers = num_workers + self.pin_memory = pin_memory + self.conditions = conditions + + def _define_sampler(self, dataset, shuffle): + if torch.distributed.is_initialized(): + return DistributedSampler(dataset, shuffle=shuffle) + if shuffle: + return RandomSampler(dataset) + return SequentialSampler(dataset) + + def _compute_batch_sizes(self, datasets): + """ + Compute batch sizes for each condition based on the specified + batching mode. 
+ + :param datasets: A dictionary mapping condition names to their + respective datasets. + :type datasets: dict[str, Dataset] + :return: A dictionary mapping condition names to their computed batch + sizes. + :rtype: dict[str, int] + """ + batch_sizes = {} + if self.batching_mode == "common_batch_size": + for name in datasets.keys(): + if self.batch_size is None: + batch_sizes[name] = len(datasets[name]) + else: + batch_sizes[name] = min( + self.batch_size, len(datasets[name]) + ) + return batch_sizes + if self.batching_mode == "proportional": + return self._compute_proportional_batch_sizes(datasets) + if self.batching_mode == "separate_conditions": + for name in datasets.keys(): + condition = self.conditions[name] + if self.batch_size is None: + batch_sizes[name] = len(datasets[name]) + else: + batch_sizes[name] = min( + self.batch_size, len(datasets[name]) + ) + return batch_sizes + raise ValueError(f"Unknown batching mode: {self.batching_mode}") + + def _compute_proportional_batch_sizes(self, datasets): + """ + Compute batch sizes for each condition proportionally based on the + size of their datasets. + :param datasets: A dictionary mapping condition names to their + respective datasets. + :type datasets: dict[str, Dataset] + :return: A dictionary mapping condition names to their computed batch + sizes. + :rtype: dict[str, int] + """ + # Compute number of elements per dataset + elements_per_dataset = { + dataset_name: len(dataset) + for dataset_name, dataset in datasets.items() + } + # Compute the total number of elements + total_elements = sum(el for el in elements_per_dataset.values()) + # Compute the portion of each dataset + portion_per_dataset = { + name: el / total_elements + for name, el in elements_per_dataset.items() + } + # Compute batch size per dataset. Ensure at least 1 element per + # dataset. + batch_size_per_dataset = { + name: max(1, int(portion * self.batch_size)) + for name, portion in portion_per_dataset.items() + } + # Adjust batch sizes to match the specified total batch size + tot_el_per_batch = sum(el for el in batch_size_per_dataset.values()) + if self.batch_size > tot_el_per_batch: + difference = self.batch_size - tot_el_per_batch + while difference > 0: + for k, v in batch_size_per_dataset.items(): + if difference == 0: + break + if v > 1: + batch_size_per_dataset[k] += 1 + difference -= 1 + if self.batch_size < tot_el_per_batch: + difference = tot_el_per_batch - self.batch_size + while difference > 0: + for k, v in batch_size_per_dataset.items(): + if difference == 0: + break + if v > 1: + batch_size_per_dataset[k] -= 1 + difference -= 1 + return batch_size_per_dataset + + def __call__(self, datasets): + """ + Create dataloaders for each condition based on the specified batching + mode. + :param datasets: A dictionary mapping condition names to their + respective datasets. + :type datasets: dict[str, Dataset] + :return: A dictionary mapping condition names to their created + dataloaders. 
+ :rtype: dict[str, DataLoader] + """ + # Compute batch sizes per condition based on batching_mode + batch_sizes = self._compute_batch_sizes(datasets) + dataloaders = {} + for name, dataset in datasets.items(): + dataloaders[name] = self.conditions[name].create_dataloader( + dataset=dataset, + batch_size=batch_sizes[name], + automatic_batching=self.automatic_batching, + sampler=self._define_sampler(dataset, self.shuffle), + num_workers=self.num_workers, + pin_memory=self.pin_memory, + ) + return dataloaders diff --git a/pina/_src/data/data_module.py b/pina/_src/data/data_module.py index f45236f0f..b39596eaf 100644 --- a/pina/_src/data/data_module.py +++ b/pina/_src/data/data_module.py @@ -12,227 +12,52 @@ from torch.utils.data.distributed import DistributedSampler from pina._src.core.label_tensor import LabelTensor from pina._src.data.dataset import PinaDatasetFactory, PinaTensorDataset +from pina._src.data.creator import _Creator +from pina._src.data.aggregator import _Aggregator -class DummyDataloader: - - def __init__(self, dataset): - """ - Prepare a dataloader object that returns the entire dataset in a single - batch. Depending on the number of GPUs, the dataset is managed - as follows: - - - **Distributed Environment** (multiple GPUs): Divides dataset across - processes using the rank and world size. Fetches only portion of - data corresponding to the current process. - - **Non-Distributed Environment** (single GPU): Fetches the entire - dataset. - - :param PinaDataset dataset: The dataset object to be processed. - - .. note:: - This dataloader is used when the batch size is ``None``. - """ - - if ( - torch.distributed.is_available() - and torch.distributed.is_initialized() - ): - rank = torch.distributed.get_rank() - world_size = torch.distributed.get_world_size() - if len(dataset) < world_size: - raise RuntimeError( - "Dimension of the dataset smaller than world size." - " Increase the size of the partition or use a single GPU" - ) - idx, i = [], rank - while i < len(dataset): - idx.append(i) - i += world_size - self.dataset = dataset.fetch_from_idx_list(idx) - else: - self.dataset = dataset.get_all_data() - - def __iter__(self): - return self - - def __len__(self): - return 1 - - def __next__(self): - return self.dataset - - -class Collator: +class _ConditionSubset: """ - This callable class is used to collate the data points fetched from the - dataset. The collation is performed based on the type of dataset used and - on the batching strategy. + This class extends the :class:`torch.utils.data.Subset` class, allowing to + fetch the data from the dataset based on a list of indices. """ - def __init__( - self, max_conditions_lengths, automatic_batching, dataset=None - ): - """ - Initialize the object, setting the collate function based on whether - automatic batching is enabled or not. - - :param dict max_conditions_lengths: ``dict`` containing the maximum - number of data points to consider in a single batch for - each condition. - :param bool automatic_batching: Whether automatic PyTorch batching is - enabled or not. For more information, see the - :class:`~pina.data.data_module.PinaDataModule` class. - :param PinaDataset dataset: The dataset where the data is stored. 
- """ - - self.max_conditions_lengths = max_conditions_lengths - # Set the collate function based on the batching strategy - # collate_pina_dataloader is used when automatic batching is disabled - # collate_torch_dataloader is used when automatic batching is enabled - self.callable_function = ( - self._collate_torch_dataloader - if automatic_batching - else (self._collate_pina_dataloader) - ) - self.dataset = dataset - - # Set the function which performs the actual collation - if isinstance(self.dataset, PinaTensorDataset): - # If the dataset is a PinaTensorDataset, use this collate function - self._collate = self._collate_tensor_dataset - else: - # If the dataset is a PinaDataset, use this collate function - self._collate = self._collate_graph_dataset - - def _collate_pina_dataloader(self, batch): - """ - Function used to create a batch when automatic batching is disabled. - - :param list[int] batch: List of integers representing the indices of - the data points to be fetched. - :return: Dictionary containing the data points fetched from the dataset. - :rtype: dict - """ - # Call the fetch_from_idx_list method of the dataset - return self.dataset.fetch_from_idx_list(batch) - - def _collate_torch_dataloader(self, batch): - """ - Function used to collate the batch - - :param list[dict] batch: List of retrieved data. - :return: Dictionary containing the data points fetched from the dataset, - collated. - :rtype: dict - """ - - batch_dict = {} - if isinstance(batch, dict): - return batch - conditions_names = batch[0].keys() - # Condition names - for condition_name in conditions_names: - single_cond_dict = {} - condition_args = batch[0][condition_name].keys() - for arg in condition_args: - data_list = [ - batch[idx][condition_name][arg] - for idx in range( - min( - len(batch), - self.max_conditions_lengths[condition_name], - ) - ) - ] - single_cond_dict[arg] = self._collate(data_list) - - batch_dict[condition_name] = single_cond_dict - return batch_dict - - @staticmethod - def _collate_tensor_dataset(data_list): - """ - Function used to collate the data when the dataset is a - :class:`~pina.data.dataset.PinaTensorDataset`. - - :param data_list: Elements to be collated. - :type data_list: list[torch.Tensor] | list[LabelTensor] - :return: Batch of data. - :rtype: dict - - :raises RuntimeError: If the data is not a :class:`torch.Tensor` or a - :class:`~pina.label_tensor.LabelTensor`. - """ - - if isinstance(data_list[0], LabelTensor): - return LabelTensor.stack(data_list) - if isinstance(data_list[0], torch.Tensor): - return torch.stack(data_list) - raise RuntimeError("Data must be Tensors or LabelTensor ") - - def _collate_graph_dataset(self, data_list): - """ - Function used to collate data when the dataset is a - :class:`~pina.data.dataset.PinaGraphDataset`. - - :param data_list: Elememts to be collated. - :type data_list: list[Data] | list[Graph] - :return: Batch of data. - :rtype: dict + def __init__(self, condition, indices, automatic_batching): + super().__init__() + self.condition = condition + self.indices = indices + self.automatic_batching = automatic_batching - :raises RuntimeError: If the data is not a - :class:`~torch_geometric.data.Data` or a :class:`~pina.graph.Graph`. 
- """ - if isinstance(data_list[0], LabelTensor): - return LabelTensor.cat(data_list) - if isinstance(data_list[0], torch.Tensor): - return torch.cat(data_list) - if isinstance(data_list[0], Data): - return self.dataset.create_batch(data_list) - raise RuntimeError( - "Data must be Tensors or LabelTensor or pyG " - "torch_geometric.data.Data" - ) + def __len__(self): + return len(self.indices) - def __call__(self, batch): + def __getitem__(self, idx): """ - Perform the collation of data fetched from the dataset. The behavoior - of the function is set based on the batching strategy during class - initialization. + Fetch the data from the dataset based on the list of indices. - :param batch: List of retrieved data or sampled indices. - :type batch: list[int] | list[dict] - :return: Dictionary containing colleted data fetched from the dataset. + :param int idx: The index of the data to be fetched. + :return: The data corresponding to the given index. :rtype: dict """ - - return self.callable_function(batch) - - -class PinaSampler: - """ - This class is used to create the sampler instance based on the shuffle - parameter and the environment in which the code is running. - """ - - def __new__(cls, dataset): - """ - Instantiate and initialize the sampler. - - :param PinaDataset dataset: The dataset from which to sample. - :return: The sampler instance. - :rtype: :class:`torch.utils.data.Sampler` - """ - - if ( - torch.distributed.is_available() - and torch.distributed.is_initialized() - ): - sampler = DistributedSampler(dataset) - else: - sampler = SequentialSampler(dataset) - return sampler + idx = self.indices[idx] + if not self.automatic_batching: + return idx + return self.condition[idx] + + def get_all_data(self): + data = self.condition[self.indices] + if "data" in data and isinstance(data["data"], list): + batch_fn = ( + LabelBatch.from_data_list + if isinstance(data["data"][0], Graph) + else Batch.from_data_list + ) + data["data"] = batch_fn(data["data"]) + data = { + "input": data["data"], + "target": data["data"].y, + } + return data class PinaDataModule(LightningDataModule): @@ -250,7 +75,7 @@ def __init__( val_size=0.1, batch_size=None, shuffle=True, - repeat=False, + batching_mode="separate_conditions", automatic_batching=None, num_workers=0, pin_memory=False, @@ -271,11 +96,9 @@ def __init__( Default is ``None``. :param bool shuffle: Whether to shuffle the dataset before splitting. Default ``True``. - :param bool repeat: If ``True``, in case of batch size larger than the - number of elements in a specific condition, the elements are - repeated until the batch size is reached. If ``False``, the number - of elements in the batch is the minimum between the batch size and - the number of elements in the condition. Default is ``False``. + :param str batching_mode: The batching mode to use. Options are + ``"common_batch_size"``, ``"proportional"``, and + ``"separate_conditions"``. Default is ``"separate_conditions"``. :param automatic_batching: If ``True``, automatic PyTorch batching is performed, which consists of extracting one element at a time from the dataset and collating them into a batch. 
This is useful @@ -302,10 +125,11 @@ def __init__( """ super().__init__() + self.problem = problem # Store fixed attributes self.batch_size = batch_size self.shuffle = shuffle - self.repeat = repeat + self.batching_mode = batching_mode self.automatic_batching = automatic_batching # If batch size is None, num_workers has no effect @@ -327,41 +151,87 @@ def __init__( self.pin_memory = False else: self.pin_memory = pin_memory - - # Collect data - problem.collect_data() - - # Check if the splits are correct + self.problem.move_discretisation_into_conditions() self._check_slit_sizes(train_size, test_size, val_size) - # Split input data into subsets - splits_dict = {} if train_size > 0: - splits_dict["train"] = train_size self.train_dataset = None else: # Use the super method to create the train dataloader which # raises NotImplementedError self.train_dataloader = super().train_dataloader if test_size > 0: - splits_dict["test"] = test_size self.test_dataset = None else: # Use the super method to create the train dataloader which # raises NotImplementedError self.test_dataloader = super().test_dataloader if val_size > 0: - splits_dict["val"] = val_size self.val_dataset = None else: # Use the super method to create the train dataloader which # raises NotImplementedError self.val_dataloader = super().val_dataloader - self.data_splits = self._create_splits( - problem.collected_data, splits_dict + self._create_condition_splits(problem, train_size, test_size, val_size) + self.creator = _Creator( + batching_mode=batching_mode, + batch_size=batch_size, + shuffle=shuffle, + automatic_batching=automatic_batching, + num_workers=num_workers, + pin_memory=pin_memory, + conditions=problem.conditions, ) - self.transfer_batch_to_device = self._transfer_batch_to_device + + @staticmethod + def _check_slit_sizes(train_size, test_size, val_size): + """ + Check if the splits are correct. The splits sizes must be positive and + the sum of the splits must be 1. + + :param float train_size: The size of the training split. + :param float test_size: The size of the testing split. + :param float val_size: The size of the validation split. + + :raises ValueError: If at least one of the splits is negative. + :raises ValueError: If the sum of the splits is different + from 1. + """ + + if train_size < 0 or test_size < 0 or val_size < 0: + raise ValueError("The splits must be positive") + if abs(train_size + test_size + val_size - 1) > 1e-6: + raise ValueError("The sum of the splits must be 1") + + def _create_condition_splits( + self, problem, train_size, test_size, val_size + ): + self.split_idxs = {} + for condition_name, condition in problem.conditions.items(): + len_condition = len(condition) + # Create the indices for shuffling and splitting + indices = ( + torch.randperm(len_condition).tolist() + if self.shuffle + else list(range(len_condition)) + ) + + # Determine split sizes + train_end = int(train_size * len_condition) + test_end = train_end + int(test_size * len_condition) + + # Split indices + train_indices = indices[:train_end] + test_indices = indices[train_end:test_end] + val_indices = indices[test_end:] + splits = {} + splits["train"], splits["test"], splits["val"] = ( + train_indices, + test_indices, + val_indices, + ) + self.split_idxs[condition_name] = splits def setup(self, stage=None): """ @@ -374,209 +244,60 @@ def setup(self, stage=None): :raises ValueError: If the stage is neither "fit" nor "test". 
""" if stage == "fit" or stage is None: - self.train_dataset = PinaDatasetFactory( - self.data_splits["train"], - max_conditions_lengths=self.find_max_conditions_lengths( - "train" - ), - automatic_batching=self.automatic_batching, - ) - if "val" in self.data_splits.keys(): - self.val_dataset = PinaDatasetFactory( - self.data_splits["val"], - max_conditions_lengths=self.find_max_conditions_lengths( - "val" - ), + print("Sono qui") + self.train_datasets = { + name: _ConditionSubset( + condition, + self.split_idxs[name]["train"], automatic_batching=self.automatic_batching, ) - elif stage == "test": - self.test_dataset = PinaDatasetFactory( - self.data_splits["test"], - max_conditions_lengths=self.find_max_conditions_lengths("test"), - automatic_batching=self.automatic_batching, - ) - else: - raise ValueError("stage must be either 'fit' or 'test'.") - - @staticmethod - def _split_condition(single_condition_dict, splits_dict): - """ - Split the condition into different stages. - - :param dict single_condition_dict: The condition to be split. - :param dict splits_dict: The dictionary containing the number of - elements in each stage. - :return: A dictionary containing the split condition. - :rtype: dict - """ - - len_condition = len(single_condition_dict["input"]) - - lengths = [ - int(len_condition * length) for length in splits_dict.values() - ] - - remainder = len_condition - sum(lengths) - for i in range(remainder): - lengths[i % len(lengths)] += 1 - - splits_dict = { - k: max(1, v) for k, v in zip(splits_dict.keys(), lengths) - } - to_return_dict = {} - offset = 0 - - for stage, stage_len in splits_dict.items(): - to_return_dict[stage] = { - k: v[offset : offset + stage_len] - for k, v in single_condition_dict.items() - if k != "equation" - # Equations are NEVER dataloaded + for name, condition in self.problem.conditions.items() + if len(self.split_idxs[name]["train"]) > 0 } - if offset + stage_len >= len_condition: - offset = len_condition - 1 - continue - offset += stage_len - return to_return_dict - - def _create_splits(self, collector, splits_dict): - """ - Create the dataset objects putting data in the correct splits. - - :param Collector collector: The collector object containing the data. - :param dict splits_dict: The dictionary containing the number of - elements in each stage. - :return: The dictionary containing the dataset objects. - :rtype: dict - """ - - # ----------- Auxiliary function ------------ - def _apply_shuffle(condition_dict, len_data): - idx = torch.randperm(len_data) - for k, v in condition_dict.items(): - if k == "equation": - continue - if isinstance(v, list): - condition_dict[k] = [v[i] for i in idx] - elif isinstance(v, LabelTensor): - condition_dict[k] = LabelTensor(v.tensor[idx], v.labels) - elif isinstance(v, torch.Tensor): - condition_dict[k] = v[idx] - else: - raise ValueError(f"Data type {type(v)} not supported") - - # ----------- End auxiliary function ------------ - - split_names = list(splits_dict.keys()) - dataset_dict = {name: {} for name in split_names} - for ( - condition_name, - condition_dict, - ) in collector.items(): - len_data = len(condition_dict["input"]) - if self.shuffle: - _apply_shuffle(condition_dict, len_data) - for key, data in self._split_condition( - condition_dict, splits_dict - ).items(): - dataset_dict[key].update({condition_name: data}) - return dataset_dict - - def _create_dataloader(self, split, dataset): - """ " - Create the dataloader for the given split. - - :param str split: The split on which to create the dataloader. 
- :param str dataset: The dataset to be used for the dataloader. - :return: The dataloader for the given split. - :rtype: torch.utils.data.DataLoader - """ - # Suppress the warning about num_workers. - # In many cases, especially for PINNs, - # serial data loading can outperform parallel data loading. - warnings.filterwarnings( - "ignore", - message=( - "The '(train|val|test)_dataloader' does not have many workers " - "which may be a bottleneck." - ), - module="lightning.pytorch.trainer.connectors.data_connector", - ) - # Use custom batching (good if batch size is large) - if self.batch_size is not None: - sampler = PinaSampler(dataset) - if self.automatic_batching: - collate = Collator( - self.find_max_conditions_lengths(split), - self.automatic_batching, - dataset=dataset, + print(self.train_datasets) + self.val_datasets = { + name: _ConditionSubset( + condition, + self.split_idxs[name]["val"], + automatic_batching=self.automatic_batching, ) - else: - collate = Collator( - None, self.automatic_batching, dataset=dataset + for name, condition in self.problem.conditions.items() + if len(self.split_idxs[name]["val"]) > 0 + } + return + if stage == "test" or stage is None: + self.test_datasets = { + name: _ConditionSubset( + condition, + self.split_idxs[name]["test"], + automatic_batching=self.automatic_batching, ) - return DataLoader( - dataset, - self.batch_size, - collate_fn=collate, - sampler=sampler, - num_workers=self.num_workers, - pin_memory=self.pin_memory, + for name, condition in self.problem.conditions.items() + if len(self.split_idxs[name]["test"]) > 0 + } + else: + raise ValueError( + f"Invalid stage {stage}. Stage must be either 'fit' or 'test'." ) - dataloader = DummyDataloader(dataset) - dataloader.dataset = self._transfer_batch_to_device( - dataloader.dataset, self.trainer.strategy.root_device, 0 - ) - self.transfer_batch_to_device = self._transfer_batch_to_device_dummy - return dataloader - - def find_max_conditions_lengths(self, split): - """ - Define the maximum length for each conditions. - - :param dict split: The split of the dataset. - :return: The maximum length per condition. - :rtype: dict - """ - - max_conditions_lengths = {} - for k, v in self.data_splits[split].items(): - if self.batch_size is None: - max_conditions_lengths[k] = len(v["input"]) - elif self.repeat: - max_conditions_lengths[k] = self.batch_size - else: - max_conditions_lengths[k] = min( - len(v["input"]), self.batch_size - ) - return max_conditions_lengths - - def val_dataloader(self): - """ - Create the validation dataloader. 
- - :return: The validation dataloader - :rtype: torch.utils.data.DataLoader - """ - return self._create_dataloader("val", self.val_dataset) def train_dataloader(self): - """ - Create the training dataloader + print(self.train_datasets) + return _Aggregator( + self.creator(self.train_datasets), + batching_mode="separate_conditions", + ) - :return: The training dataloader - :rtype: torch.utils.data.DataLoader - """ - return self._create_dataloader("train", self.train_dataset) + def val_dataloader(self): + print(self.val_datasets) + return _Aggregator( + self.creator(self.val_datasets), batching_mode="separate_conditions" + ) def test_dataloader(self): - """ - Create the testing dataloader - - :return: The testing dataloader - :rtype: torch.utils.data.DataLoader - """ - return self._create_dataloader("test", self.test_dataset) + return _Aggregator( + self.creator(self.test_datasets), + batching_mode="separate_conditions", + ) @staticmethod def _transfer_batch_to_device_dummy(batch, device, dataloader_idx): @@ -591,10 +312,9 @@ def _transfer_batch_to_device_dummy(batch, device, dataloader_idx): :return: The batch transferred to the device. :rtype: list[tuple] """ - return batch - def _transfer_batch_to_device(self, batch, device, dataloader_idx): + def transfer_batch_to_device(self, batch, device, dataloader_idx): """ Transfer the batch to the device. This method is called in the training loop and is used to transfer the batch to the device. @@ -606,53 +326,7 @@ def _transfer_batch_to_device(self, batch, device, dataloader_idx): :return: The batch transferred to the device. :rtype: list[tuple] """ - - batch = [ - ( - k, - super(LightningDataModule, self).transfer_batch_to_device( - v, device, dataloader_idx - ), - ) - for k, v in batch.items() - ] - - return batch - - @staticmethod - def _check_slit_sizes(train_size, test_size, val_size): - """ - Check if the splits are correct. The splits sizes must be positive and - the sum of the splits must be 1. - - :param float train_size: The size of the training split. - :param float test_size: The size of the testing split. - :param float val_size: The size of the validation split. - - :raises ValueError: If at least one of the splits is negative. - :raises ValueError: If the sum of the splits is different - from 1. - """ - - if train_size < 0 or test_size < 0 or val_size < 0: - raise ValueError("The splits must be positive") - if abs(train_size + test_size + val_size - 1) > 1e-6: - raise ValueError("The sum of the splits must be 1") - - @property - def input(self): - """ - Return all the input points coming from all the datasets. - - :return: The input points for training. 
- :rtype: dict - """ - - to_return = {} - if hasattr(self, "train_dataset") and self.train_dataset is not None: - to_return["train"] = self.train_dataset.input - if hasattr(self, "val_dataset") and self.val_dataset is not None: - to_return["val"] = self.val_dataset.input - if hasattr(self, "test_dataset") and self.test_dataset is not None: - to_return["test"] = self.test_dataset.input + to_return = [] + for condition_name, condition in batch.items(): + to_return.append((condition_name, condition.to(device))) return to_return diff --git a/pina/_src/data/dummy_dataloader.py b/pina/_src/data/dummy_dataloader.py new file mode 100644 index 000000000..c236e9d30 --- /dev/null +++ b/pina/_src/data/dummy_dataloader.py @@ -0,0 +1,62 @@ +""" +Module containing the ``DummyDataloader`` class +""" + +import torch + + +class DummyDataloader: + """ + A dummy dataloader that returns the entire dataset in a single batch. This + is used when the batch size is ``None``. It supports both distributed and + non-distributed environments. In a distributed environment, it divides the + dataset across processes using the rank and world size, fetching only the + portion of data corresponding to the current process. In a non-distributed + environment, it fetches the entire dataset. + """ + + def __init__(self, dataset): + """ + Prepare a dataloader object that returns the entire dataset in a single + batch. Depending on the number of GPUs, the dataset is managed + as follows: + + - **Distributed Environment** (multiple GPUs): Divides dataset across + processes using the rank and world size. Fetches only portion of + data corresponding to the current process. + - **Non-Distributed Environment** (single GPU): Fetches the entire + dataset. + + :param PinaDataset dataset: The dataset object to be processed. + + .. note:: + This dataloader is used when the batch size is ``None``. + """ + + if ( + torch.distributed.is_available() + and torch.distributed.is_initialized() + ): + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + if len(dataset) < world_size: + raise RuntimeError( + "Dimension of the dataset smaller than world size." + " Increase the size of the partition or use a single GPU" + ) + idx, i = [], rank + while i < len(dataset): + idx.append(i) + i += world_size + self.dataset = dataset.fetch_from_idx_list(idx).to_batch() + else: + self.dataset = dataset.get_all_data().to_batch() + + def __iter__(self): + return self + + def __len__(self): + return 1 + + def __next__(self): + return self.dataset diff --git a/pina/_src/problem/abstract_problem.py b/pina/_src/problem/abstract_problem.py index cfaeb5bec..b781c8067 100644 --- a/pina/_src/problem/abstract_problem.py +++ b/pina/_src/problem/abstract_problem.py @@ -11,6 +11,7 @@ ) from pina._src.core.label_tensor import LabelTensor from pina._src.core.utils import merge_tensors, custom_warning_format +from pina._src.condition.condition import Condition class AbstractProblem(metaclass=ABCMeta): @@ -318,34 +319,49 @@ def add_points(self, new_points_dict): [self.discretised_domains[k], v] ) - def collect_data(self): + def move_discretisation_into_conditions(self): """ - Aggregate data from the problem's conditions into a single dictionary. + Move the discretised domains into their corresponding conditions. 
""" - data = {} - # Iterate over the conditions and collect data - for condition_name in self.conditions: - condition = self.conditions[condition_name] - # Check if the condition has an domain attribute - if hasattr(condition, "domain"): - # Only store the discretisation points if the domain is - # in the dictionary - if condition.domain in self.discretised_domains: - samples = self.discretised_domains[condition.domain][ - self.input_variables - ] - data[condition_name] = { - "input": samples, - "equation": condition.equation, - } - else: - # If the condition does not have a domain attribute, store - # the input and target points - keys = condition.__slots__ - values = [ - getattr(condition, name) - for name in keys - if getattr(condition, name) is not None - ] - data[condition_name] = dict(zip(keys, values)) - self._collected_data = data + + for name, cond in self.conditions.items(): + if hasattr(cond, "domain"): + domain = cond.domain + self.conditions[name] = Condition( + input=self.discretised_domains[cond.domain], + equation=cond.equation, + ) + self.conditions[name].domain = domain + self.conditions[name].problem = self + + # def collect_data(self): + # """ + # Aggregate data from the problem's conditions into a single dictionary. + # """ + # data = {} + # # Iterate over the conditions and collect data + # for condition_name in self.conditions: + # condition = self.conditions[condition_name] + # # Check if the condition has an domain attribute + # if hasattr(condition, "domain"): + # # Only store the discretisation points if the domain is + # # in the dictionary + # if condition.domain in self.discretised_domains: + # samples = self.discretised_domains[condition.domain][ + # self.input_variables + # ] + # data[condition_name] = { + # "input": samples, + # "equation": condition.equation, + # } + # else: + # # If the condition does not have a domain attribute, store + # # the input and target points + # keys = condition.__slots__ + # values = [ + # getattr(condition, name) + # for name in keys + # if getattr(condition, name) is not None + # ] + # data[condition_name] = dict(zip(keys, values)) + # self._collected_data = data From 975f0ef0d4a3a48d12214217eef171be8fa1b50b Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Fri, 13 Feb 2026 09:45:25 +0100 Subject: [PATCH 2/2] remove useless code in abstract_problem.py --- pina/_src/problem/abstract_problem.py | 88 ++++++--------------------- pina/data/__init__.py | 18 ------ 2 files changed, 19 insertions(+), 87 deletions(-) diff --git a/pina/_src/problem/abstract_problem.py b/pina/_src/problem/abstract_problem.py index b781c8067..cc2b9e042 100644 --- a/pina/_src/problem/abstract_problem.py +++ b/pina/_src/problem/abstract_problem.py @@ -43,43 +43,6 @@ def __init__(self): self.domains[cond_name] = cond.domain cond.domain = cond_name - self._collected_data = {} - - @property - def collected_data(self): - """ - Return the collected data from the problem's conditions. If some domains - are not sampled, they will not be returned by collected data. - - :return: The collected data. Keys are condition names, and values are - dictionaries containing the input points and the corresponding - equations or target points. 
- :rtype: dict - """ - # collect data so far - self.collect_data() - # raise warning if some sample data are missing - if not self.are_all_domains_discretised: - warnings.formatwarning = custom_warning_format - warnings.filterwarnings("always", category=RuntimeWarning) - warning_message = "\n".join( - [ - f"""{" " * 13} ---> Domain {key} { - "sampled" if key in self.discretised_domains - else - "not sampled"}""" - for key in self.domains - ] - ) - warnings.warn( - "Some of the domains are still not sampled. Consider calling " - "problem.discretise_domain function for all domains before " - "accessing the collected data:\n" - f"{warning_message}", - RuntimeWarning, - ) - return self._collected_data - # back compatibility 0.1 @property def input_pts(self): @@ -323,6 +286,25 @@ def move_discretisation_into_conditions(self): """ Move the discretised domains into their corresponding conditions. """ + if not self.are_all_domains_discretised: + warnings.formatwarning = custom_warning_format + warnings.filterwarnings("always", category=RuntimeWarning) + warning_message = "\n".join( + [ + f"""{" " * 13} ---> Domain {key} { + "sampled" if key in self.discretised_domains + else + "not sampled"}""" + for key in self.domains + ] + ) + warnings.warn( + "Some of the domains are still not sampled. Consider calling " + "problem.discretise_domain function for all domains before " + "accessing the collected data:\n" + f"{warning_message}", + RuntimeWarning, + ) for name, cond in self.conditions.items(): if hasattr(cond, "domain"): @@ -333,35 +315,3 @@ def move_discretisation_into_conditions(self): ) self.conditions[name].domain = domain self.conditions[name].problem = self - - # def collect_data(self): - # """ - # Aggregate data from the problem's conditions into a single dictionary. - # """ - # data = {} - # # Iterate over the conditions and collect data - # for condition_name in self.conditions: - # condition = self.conditions[condition_name] - # # Check if the condition has an domain attribute - # if hasattr(condition, "domain"): - # # Only store the discretisation points if the domain is - # # in the dictionary - # if condition.domain in self.discretised_domains: - # samples = self.discretised_domains[condition.domain][ - # self.input_variables - # ] - # data[condition_name] = { - # "input": samples, - # "equation": condition.equation, - # } - # else: - # # If the condition does not have a domain attribute, store - # # the input and target points - # keys = condition.__slots__ - # values = [ - # getattr(condition, name) - # for name in keys - # if getattr(condition, name) is not None - # ] - # data[condition_name] = dict(zip(keys, values)) - # self._collected_data = data diff --git a/pina/data/__init__.py b/pina/data/__init__.py index 2ecebecdd..f274d5bd9 100644 --- a/pina/data/__init__.py +++ b/pina/data/__init__.py @@ -7,26 +7,8 @@ from pina._src.data.data_module import ( PinaDataModule, - PinaSampler, - DummyDataloader, - Collator, - PinaSampler, -) - -from pina._src.data.dataset import ( - PinaDataset, - PinaTensorDataset, - PinaGraphDataset, - PinaDatasetFactory, ) __all__ = [ "PinaDataModule", - "PinaDataset", - "PinaSampler", - "DummyDataloader", - "Collator", - "PinaTensorDataset", - "PinaGraphDataset", - "PinaDatasetFactory", ]
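
For reference, the snippet below is a minimal, self-contained sketch of how the "proportional" batching mode splits one total batch size across conditions, mirroring the logic of _Creator._compute_proportional_batch_sizes introduced above. It is illustrative only: the function name, the example condition names, and the dataset sizes are made up for the demo, and the assertion assumes the requested batch size is at least the number of conditions.

    # Sketch (not the library code verbatim): allocate a total batch size
    # across conditions in proportion to their dataset sizes, keeping at
    # least one sample per condition and nudging the result so the
    # per-condition sizes sum back to the requested total.
    def proportional_batch_sizes(dataset_lengths, total_batch_size):
        assert total_batch_size >= len(dataset_lengths)
        total = sum(dataset_lengths.values())
        sizes = {
            name: max(1, int(total_batch_size * length / total))
            for name, length in dataset_lengths.items()
        }
        # Fix the rounding error so the sizes add up to total_batch_size.
        diff = total_batch_size - sum(sizes.values())
        while diff != 0:
            for name in sizes:
                if diff == 0:
                    break
                if diff > 0:
                    sizes[name] += 1
                    diff -= 1
                elif sizes[name] > 1:
                    sizes[name] -= 1
                    diff += 1
        return sizes

    # Three conditions of very different sizes sharing a total batch of 64.
    print(proportional_batch_sizes({"pde": 1000, "bc": 200, "data": 50}, 64))
    # {'pde': 52, 'bc': 10, 'data': 2}

As implemented by _Creator and _Aggregator in this patch, "proportional" interprets batch_size as the total across conditions as above, "common_batch_size" applies the given batch size (capped at the dataset length) to every condition, and "separate_conditions" sizes batches like "common_batch_size" but yields each condition as its own batch rather than zipping all conditions into one step. Trainer defaults to "common_batch_size", while PinaDataModule itself defaults to "separate_conditions".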