diff --git a/docs/source/_rst/_code.rst b/docs/source/_rst/_code.rst
index 64d88bc8b..7d992d1ca 100644
--- a/docs/source/_rst/_code.rst
+++ b/docs/source/_rst/_code.rst
@@ -82,6 +82,8 @@ Solvers
     DeepEnsembleSupervisedSolver
     ReducedOrderModelSolver
     GAROM
+    AutoregressiveSolverInterface
+    AutoregressiveSolver
 
 Models
diff --git a/docs/source/_rst/solver/autoregressive_solver/autoregressive_solver.rst b/docs/source/_rst/solver/autoregressive_solver/autoregressive_solver.rst
new file mode 100644
index 000000000..4cde8d1b9
--- /dev/null
+++ b/docs/source/_rst/solver/autoregressive_solver/autoregressive_solver.rst
@@ -0,0 +1,7 @@
+Autoregressive Solver
+=====================
+.. currentmodule:: pina.solver.autoregressive_solver.autoregressive_solver
+
+.. autoclass:: pina._src.solver.autoregressive_solver.autoregressive_solver.AutoregressiveSolver
+    :members:
+    :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/solver/autoregressive_solver/autoregressive_solver_interface.rst b/docs/source/_rst/solver/autoregressive_solver/autoregressive_solver_interface.rst
new file mode 100644
index 000000000..516409bd1
--- /dev/null
+++ b/docs/source/_rst/solver/autoregressive_solver/autoregressive_solver_interface.rst
@@ -0,0 +1,7 @@
+Autoregressive Solver Interface
+===============================
+.. currentmodule:: pina.solver.autoregressive_solver.autoregressive_solver_interface
+
+.. autoclass:: pina._src.solver.autoregressive_solver.autoregressive_solver_interface.AutoregressiveSolverInterface
+    :members:
+    :show-inheritance:
\ No newline at end of file
diff --git a/pina/_src/solver/autoregressive_solver/__init__.py b/pina/_src/solver/autoregressive_solver/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pina/_src/solver/autoregressive_solver/autoregressive_solver.py b/pina/_src/solver/autoregressive_solver/autoregressive_solver.py
new file mode 100644
index 000000000..1c8630b7e
--- /dev/null
+++ b/pina/_src/solver/autoregressive_solver/autoregressive_solver.py
@@ -0,0 +1,396 @@
+"""Module for the Autoregressive Solver."""
+
+import torch
+
+from pina._src.solver.autoregressive_solver.autoregressive_solver_interface import (
+    AutoregressiveSolverInterface,
+)
+from pina._src.solver.solver import SingleSolverInterface
+from pina._src.loss.loss_interface import LossInterface
+from pina._src.core.utils import check_consistency
+
+
+class AutoregressiveSolver(
+    AutoregressiveSolverInterface, SingleSolverInterface
+):
+    r"""
+    The Autoregressive Solver for learning dynamical systems.
+
+    This solver learns a one-step transition function
+    :math:`\mathcal{M}: \mathbb{R}^n \rightarrow \mathbb{R}^n` that maps
+    a state :math:`\mathbf{y}_t` to the next state :math:`\mathbf{y}_{t+1}`.
+
+    During training, the model is unrolled over multiple time steps to
+    learn long-term dynamics. Given an initial state :math:`\mathbf{y}_0`,
+    the model generates predictions recursively:
+
+    .. math::
+        \hat{\mathbf{y}}_{t+1} = \mathcal{M}(\hat{\mathbf{y}}_t),
+        \quad \hat{\mathbf{y}}_0 = \mathbf{y}_0
+
+    The loss is computed over the entire unroll window:
+
+    .. math::
+        \mathcal{L} = \sum_{t=1}^{T} w_t \|\hat{\mathbf{y}}_t - \mathbf{y}_t\|^2
+
+    where :math:`w_t` are exponential weights that down-weight later
+    predictions to stabilize training.
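+
+    :Example:
+        A minimal usage sketch, assuming ``data`` is a tensor of shape
+        ``[trajectories, time_steps, features]`` and ``problem`` holds a
+        :class:`~pina.condition.data_condition.DataCondition` (the variable
+        names here are illustrative; see the tests in this PR for the full
+        workflow):
+
+        .. code-block:: python
+
+            >>> windows = AutoregressiveSolver.unroll(data, unroll_length=5)
+            >>> problem.conditions["data"] = DataCondition(input=windows)
+            >>> solver = AutoregressiveSolver(problem=problem, model=model)
+            >>> Trainer(solver, max_epochs=10, accelerator="cpu").train()
+            >>> trajectory = solver.predict(data[:, :1, :], n_steps=10)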
+    """
+
+    def __init__(
+        self,
+        problem,
+        model,
+        loss=None,
+        optimizer=None,
+        scheduler=None,
+        weighting=None,
+        use_lt=False,
+        reset_weights_at_epoch_start=True,
+    ):
+        """
+        Initialization of the :class:`AutoregressiveSolver` class.
+
+        :param AbstractProblem problem: The problem to be solved.
+        :param torch.nn.Module model: The neural network model to be used.
+        :param torch.nn.Module loss: The loss function to be minimized.
+            If ``None``, the :class:`torch.nn.MSELoss` loss is used.
+            Default is ``None``.
+        :param Optimizer optimizer: The optimizer to be used.
+            If ``None``, the :class:`torch.optim.Adam` optimizer is used.
+            Default is ``None``.
+        :param Scheduler scheduler: Learning rate scheduler.
+            If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
+            scheduler is used. Default is ``None``.
+        :param WeightingInterface weighting: The weighting schema to be used.
+            If ``None``, no weighting schema is used. Default is ``None``.
+        :param bool use_lt: Whether to use LabelTensors. Default is ``False``.
+        :param bool reset_weights_at_epoch_start: If ``True``, the running
+            averages used for adaptive weighting are reset at the start of
+            each epoch. Setting this parameter to ``False`` can improve
+            training stability, especially when data are scarce.
+            Default is ``True``.
+        :raise ValueError: If the provided loss function is not compatible.
+        """
+        super().__init__(
+            problem=problem,
+            model=model,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            weighting=weighting,
+            use_lt=use_lt,
+        )
+
+        # Check consistency
+        loss = loss or torch.nn.MSELoss()
+        check_consistency(
+            loss, (LossInterface, torch.nn.modules.loss._Loss), subclass=False
+        )
+
+        # Initialization
+        self._loss_fn = loss
+        self.reset_weights_at_epoch_start = reset_weights_at_epoch_start
+        self._running_avg = {}
+        self._step_count = {}
+
+    def on_train_epoch_start(self):
+        """
+        Clear the cached running averages at the start of each epoch if
+        ``reset_weights_at_epoch_start`` is ``True``.
+        """
+        if self.reset_weights_at_epoch_start:
+            self._running_avg.clear()
+            self._step_count.clear()
+
+    def optimization_cycle(self, batch):
+        """
+        The optimization cycle for autoregressive solvers.
+
+        :param list[tuple[str, dict]] batch: A batch of data. Each element is
+            a tuple containing a condition name and a dictionary of points.
+        :return: The losses computed for all conditions in the batch.
+        :rtype: dict
+        """
+        # Store losses for each condition in the batch
+        condition_loss = {}
+
+        # Loop through each condition and compute the autoregressive loss
+        for condition_name, points in batch:
+            # TODO: remove settings once AutoregressiveCondition is implemented
+            # TODO: pass a temporal weighting schema in the __init__
+            if hasattr(self.problem.conditions[condition_name], "settings"):
+                settings = self.problem.conditions[condition_name].settings
+                eps = settings.get("eps", None)
+                kwargs = settings.get("kwargs", {})
+            else:
+                eps = None
+                kwargs = {}
+
+            loss = self.loss_autoregressive(
+                points["input"],
+                condition_name=condition_name,
+                eps=eps,
+                **kwargs,
+            )
+            condition_loss[condition_name] = loss
+        return condition_loss
+
+    def loss_autoregressive(
+        self,
+        input,
+        condition_name,
+        eps=None,
+        aggregation_strategy=None,
+        **kwargs,
+    ):
+        """
+        Compute the loss for each autoregressive condition.
+
+        :param input: The input tensor containing unroll windows, of shape
+            ``[trajectories, windows, time_steps, *features]``.
+        :type input: torch.Tensor | LabelTensor
+        :param str condition_name: The name of the condition being evaluated,
+            used as the key for caching the running averages of step losses.
+        :param float eps: The exponential down-weighting rate for later
+            unroll steps. If ``None``, uniform weights are used.
+            Default is ``None``.
+        :param callable aggregation_strategy: The function used to aggregate
+            the weighted step losses into a scalar. If ``None``,
+            :func:`torch.mean` is used. Default is ``None``.
+        :param dict kwargs: Additional keyword arguments passed to the
+            pre- and post-processing hooks.
+        :raise ValueError: If ``input`` has less than 4 dimensions.
+        :return: The scalar loss value for the given batch.
+        :rtype: torch.Tensor | LabelTensor
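+
+        :Example:
+            A shape-level sketch (``data`` and the ``eps`` value are
+            illustrative only):
+
+            .. code-block:: python
+
+                >>> windows = AutoregressiveSolver.unroll(data, unroll_length=5)
+                >>> loss = solver.loss_autoregressive(
+                ...     windows, condition_name="data", eps=0.1
+                ... )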
+        """
+        # Check input dimensionality
+        if input.dim() < 4:
+            raise ValueError(
+                "The provided input tensor must have at least 4 dimensions:"
+                " [trajectories, windows, time_steps, *features]."
+                f" Got shape {input.shape}."
+            )
+
+        # Initialize current state and loss list
+        current_state = input[:, :, 0, ...]
+        losses = []
+
+        # Iterate through the unroll window and compute the loss at each step
+        for step in range(1, input.shape[2]):
+
+            # Predict
+            processed_input = self.preprocess_step(current_state, **kwargs)
+            output = self.forward(processed_input)
+            predicted_state = self.postprocess_step(output, **kwargs)
+
+            # Compute step loss
+            target_state = input[:, :, step, ...]
+            step_loss = self._loss_fn(predicted_state, target_state)
+            losses.append(step_loss)
+
+            # Update current state for the next step
+            current_state = predicted_state
+
+        # Stack step losses into a tensor of shape [time_steps - 1]
+        step_losses = torch.stack(losses)
+
+        # Compute adaptive weights based on running averages of step losses
+        with torch.no_grad():
+            weights = self._get_weights(condition_name, step_losses, eps)
+
+        # Aggregate the weighted step losses into a single scalar loss value
+        if aggregation_strategy is None:
+            aggregation_strategy = torch.mean
+
+        return aggregation_strategy(step_losses * weights)
+
+    def preprocess_step(self, current_state, **kwargs):
+        """
+        Pre-process the current state before passing it to the model's
+        forward pass.
+
+        :param current_state: The current state to be preprocessed.
+        :type current_state: torch.Tensor | LabelTensor
+        :param dict kwargs: Additional keyword arguments for pre-processing.
+        :return: The preprocessed state for the given step.
+        :rtype: torch.Tensor | LabelTensor
+        """
+        return current_state
+
+    def postprocess_step(self, predicted_state, **kwargs):
+        """
+        Post-process the state predicted by the model.
+
+        :param predicted_state: The predicted state tensor from the model.
+        :type predicted_state: torch.Tensor | LabelTensor
+        :param dict kwargs: Additional keyword arguments for post-processing.
+        :return: The post-processed predicted state tensor.
+        :rtype: torch.Tensor | LabelTensor
+        """
+        return predicted_state
+
+    def _get_weights(self, condition_name, step_losses, eps):
+        """
+        Return weights computed from the cached running averages of the step
+        losses, updating the cache with the current values.
+
+        :param str condition_name: The name of the current condition.
+        :param torch.Tensor step_losses: The tensor of per-step losses.
+        :param float eps: The exponential down-weighting rate. If ``None``,
+            uniform weights are returned.
+        :return: The weights tensor.
+        :rtype: torch.Tensor
+        """
+        # Determine the key for caching based on the condition name
+        key = condition_name or "default"
+
+        # Initialize the key if not in the running averages
+        if key not in self._running_avg:
+            self._running_avg[key] = step_losses.detach().clone()
+            self._step_count[key] = 1
+
+        # Update running averages and counts
+        else:
+            self._step_count[key] += 1
+            value = step_losses.detach() - self._running_avg[key]
+            self._running_avg[key] += value / self._step_count[key]
+
+        return self._compute_adaptive_weights(self._running_avg[key], eps)
+
+    def _compute_adaptive_weights(self, step_losses, eps):
+        """
+        Compute temporal adaptive weights.
+
+        :param torch.Tensor step_losses: The tensor of per-step losses.
+        :param float eps: The exponential down-weighting rate. If ``None``,
+            uniform weights are returned.
+        :return: The weights tensor.
+        :rtype: torch.Tensor
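+
+        :Example:
+            A worked example for this internal helper: for step losses
+            ``[1, 2, 3]`` and ``eps = 0.1``, the cumulative sums are
+            ``[1, 3, 6]``, so the weights are ``exp(-0.1 * [1, 3, 6])``
+            (values rounded to the default print precision):
+
+            .. code-block:: python
+
+                >>> solver._compute_adaptive_weights(
+                ...     torch.tensor([1.0, 2.0, 3.0]), eps=0.1
+                ... )
+                tensor([0.9048, 0.7408, 0.5488])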
+        """
+        # If eps is None, return uniform weights
+        if eps is None:
+            return torch.ones_like(step_losses)
+
+        # Compute cumulative loss and apply exponential weighting
+        cumulative_loss = -eps * torch.cumsum(step_losses, dim=0)
+
+        return torch.exp(cumulative_loss)
+
+    def predict(self, initial_state, n_steps, **kwargs):
+        """
+        Generate predictions by recursively calling the model's forward pass.
+
+        :param initial_state: The initial state from which to start the
+            prediction, of shape ``[trajectories, 1, *features]``.
+        :type initial_state: torch.Tensor | LabelTensor
+        :param int n_steps: The number of autoregressive steps to predict.
+        :param dict kwargs: Additional keyword arguments.
+        :raise ValueError: If the provided ``initial_state`` tensor has less
+            than 3 dimensions.
+        :return: The predicted trajectory, including the initial state. It has
+            shape ``[trajectories, 1, n_steps + 1, *features]``, where the
+            first entry along the time dimension is the initial state.
+        :rtype: torch.Tensor | LabelTensor
+        """
+        # Set model to evaluation mode for prediction
+        self.eval()
+
+        # Check initial state dimensionality
+        if initial_state.dim() < 3:
+            raise ValueError(
+                "The provided initial_state tensor must have at least 3"
+                " dimensions: [trajectories, time_steps, *features]."
+                f" Got shape {initial_state.shape}."
+            )
+
+        # Initialize the list of predictions with the initial state
+        predictions = [initial_state]
+
+        # Generate predictions recursively for n_steps
+        with torch.no_grad():
+            for _ in range(n_steps):
+                input = self.preprocess_step(predictions[-1], **kwargs)
+                output = self.forward(input)
+                next_state = self.postprocess_step(output, **kwargs)
+                predictions.append(next_state)
+
+        return torch.stack(predictions, dim=2)
+
+    # TODO: integrate in the AutoregressiveCondition once implemented
+    @staticmethod
+    def unroll(data, unroll_length, n_unrolls=None, randomize=True):
+        """
+        Create unrolling time windows from temporal data.
+
+        This function takes as input a tensor of shape
+        ``[trajectories, time_steps, *features]`` and produces a tensor of
+        shape ``[trajectories, windows, unroll_length, *features]``.
+        Each window contains a sequence of subsequent states used for
+        computing the multi-step loss during training.
+
+        :param data: The temporal data tensor to be unrolled.
+        :type data: torch.Tensor | LabelTensor
+        :param int unroll_length: The number of time steps in each window.
+        :param int n_unrolls: The maximum number of windows to return.
+            If ``None``, all valid windows are returned. Default is ``None``.
+        :param bool randomize: If ``True``, starting indices are randomly
+            permuted before applying ``n_unrolls``. Default is ``True``.
+        :raise ValueError: If the input ``data`` has less than 3 dimensions.
+        :raise ValueError: If ``unroll_length`` is greater than the number of
+            time steps in ``data``.
+        :return: A tensor of unrolled windows.
+        :rtype: torch.Tensor | LabelTensor
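+
+        :Example:
+            For a toy tensor of shape ``[2, 10, 3]``, a window length of 4
+            yields 7 valid starting indices:
+
+            .. code-block:: python
+
+                >>> data = torch.rand(2, 10, 3)
+                >>> windows = AutoregressiveSolver.unroll(data, unroll_length=4)
+                >>> windows.shape
+                torch.Size([2, 7, 4, 3])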
+        """
+        # Check input dimensionality
+        if data.dim() < 3:
+            raise ValueError(
+                "The provided data tensor must have at least 3 dimensions:"
+                " [trajectories, time_steps, *features]."
+                f" Got shape {data.shape}."
+            )
+
+        # Determine valid starting indices for unroll windows
+        start_idx = AutoregressiveSolver._get_start_idx(
+            n_steps=data.shape[1],
+            unroll_length=unroll_length,
+            n_unrolls=n_unrolls,
+            randomize=randomize,
+        )
+
+        # Create unroll windows by slicing the data tensor at starting indices
+        windows = [data[:, s : s + unroll_length, ...] for s in start_idx]
+
+        return torch.stack(windows, dim=1)
+
+    @staticmethod
+    def _get_start_idx(n_steps, unroll_length, n_unrolls=None, randomize=True):
+        """
+        Determine starting indices for unroll windows.
+
+        :param int n_steps: The total number of time steps in the data.
+        :param int unroll_length: The number of time steps in each window.
+        :param int n_unrolls: The maximum number of windows to return.
+            If ``None``, all valid windows are returned. Default is ``None``.
+        :param bool randomize: If ``True``, starting indices are randomly
+            permuted before applying ``n_unrolls``. Default is ``True``.
+        :raise ValueError: If ``unroll_length`` is greater than ``n_steps``.
+        :return: A tensor of starting indices for unroll windows.
+        :rtype: torch.Tensor
+        """
+        # Calculate the last valid starting index for unroll windows
+        last_idx = n_steps - unroll_length
+
+        # Raise error if no valid windows can be created
+        if last_idx < 0:
+            raise ValueError(
+                f"Cannot create unroll windows: unroll_length ({unroll_length})"
+                " cannot be greater than the number of time_steps"
+                f" ({n_steps})."
+            )
+
+        # Generate ordered starting indices for unroll windows
+        indices = torch.arange(last_idx + 1)
+
+        # Permute indices if randomization is enabled
+        if randomize:
+            indices = indices[torch.randperm(len(indices))]
+
+        # Limit the number of windows if n_unrolls is specified
+        if n_unrolls is not None and n_unrolls < len(indices):
+            indices = indices[:n_unrolls]
+
+        return indices
+
+    @property
+    def loss(self):
+        """
+        The loss function to be minimized.
+
+        :return: The loss function to be minimized.
+        :rtype: torch.nn.Module
+        """
+        return self._loss_fn
diff --git a/pina/_src/solver/autoregressive_solver/autoregressive_solver_interface.py b/pina/_src/solver/autoregressive_solver/autoregressive_solver_interface.py
new file mode 100644
index 000000000..7029995fd
--- /dev/null
+++ b/pina/_src/solver/autoregressive_solver/autoregressive_solver_interface.py
@@ -0,0 +1,82 @@
+"""Module for the Autoregressive Solver Interface."""
+
+from abc import abstractmethod
+from pina._src.condition.data_condition import DataCondition
+from pina._src.solver.solver import SolverInterface
+
+
+class AutoregressiveSolverInterface(SolverInterface):
+    # TODO: fix once the AutoregressiveCondition is implemented.
+    """
+    Abstract interface for all autoregressive solvers.
+
+    Any solver implementing this interface is expected to learn dynamical
+    systems in an autoregressive manner. The solver should handle conditions
+    of type :class:`~pina.condition.data_condition.DataCondition`.
+    """
+
+    accepted_conditions_types = (DataCondition,)
+
+    @abstractmethod
+    def preprocess_step(self, current_state, **kwargs):
+        """
+        Pre-process the current state before passing it to the model's
+        forward pass.
+
+        :param current_state: The current state to be preprocessed.
+        :type current_state: torch.Tensor | LabelTensor
+        :param dict kwargs: Additional keyword arguments for pre-processing.
+        :return: The preprocessed state for the given step.
+        :rtype: torch.Tensor | LabelTensor
+        """
+
+    @abstractmethod
+    def postprocess_step(self, predicted_state, **kwargs):
+        """
+        Post-process the state predicted by the model.
+
+        :param predicted_state: The predicted state tensor from the model.
+        :type predicted_state: torch.Tensor | LabelTensor
+        :param dict kwargs: Additional keyword arguments for post-processing.
+        :return: The post-processed predicted state tensor.
+        :rtype: torch.Tensor | LabelTensor
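+
+        :Example:
+            A sketch of a possible override, assuming hypothetical
+            normalization constants ``self.mean`` and ``self.std`` defined by
+            the subclass:
+
+            .. code-block:: python
+
+                def postprocess_step(self, predicted_state, **kwargs):
+                    # Map the model output back to physical units.
+                    return predicted_state * self.std + self.mean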
+        """
+
+    # TODO: remove once the AutoregressiveCondition is implemented.
+    @abstractmethod
+    def loss_autoregressive(self, input, **kwargs):
+        """
+        Compute the loss for each autoregressive condition.
+
+        :param input: The input tensor containing unroll windows.
+        :type input: torch.Tensor | LabelTensor
+        :param dict kwargs: Additional keyword arguments for loss computation.
+        :return: The scalar loss value for the given batch.
+        :rtype: torch.Tensor | LabelTensor
+        """
+
+    @abstractmethod
+    def predict(self, initial_state, n_steps, **kwargs):
+        """
+        Generate predictions by recursively applying the model.
+
+        :param initial_state: The initial state from which to start the
+            prediction, of shape ``[trajectories, 1, *features]``, where the
+            trajectory dimension can be used for batching.
+        :type initial_state: torch.Tensor | LabelTensor
+        :param int n_steps: The number of autoregressive steps to predict.
+        :param dict kwargs: Additional keyword arguments.
+        :return: The predicted trajectory, including the initial state as its
+            first step.
+        :rtype: torch.Tensor | LabelTensor
+        """
+
+    @property
+    @abstractmethod
+    def loss(self):
+        """
+        The loss function to be minimized.
+
+        :return: The loss function to be minimized.
+        :rtype: torch.nn.Module
+        """
diff --git a/pina/solver/__init__.py b/pina/solver/__init__.py
index a93914099..619e59d04 100644
--- a/pina/solver/__init__.py
+++ b/pina/solver/__init__.py
@@ -27,6 +27,8 @@
     "DeepEnsembleSupervisedSolver",
     "DeepEnsemblePINN",
     "GAROM",
+    "AutoregressiveSolver",
+    "AutoregressiveSolverInterface",
 ]
 
 from pina._src.solver.solver import (
@@ -64,3 +66,10 @@
 )
 
 from pina._src.solver.garom import GAROM
+
+from pina._src.solver.autoregressive_solver.autoregressive_solver import (
+    AutoregressiveSolver,
+)
+from pina._src.solver.autoregressive_solver.autoregressive_solver_interface import (
+    AutoregressiveSolverInterface,
+)
diff --git a/tests/test_solver/test_autoregressive_solver.py b/tests/test_solver/test_autoregressive_solver.py
new file mode 100644
index 000000000..29e5628b9
--- /dev/null
+++ b/tests/test_solver/test_autoregressive_solver.py
@@ -0,0 +1,302 @@
+import pytest
+import torch
+
+from pina import Trainer
+from pina.optim import TorchOptimizer
+from pina.problem import AbstractProblem
+from pina.condition.data_condition import DataCondition
+from pina.solver import AutoregressiveSolver
+
+# Set random seed for reproducibility
+torch.manual_seed(42)
+NUM_TIMESERIES = 2
+
+
+def _make_series(T, F):
+    torch.manual_seed(42)
+    num_t_series = NUM_TIMESERIES
+    y = torch.zeros(num_t_series, T, F)
+    y[0] = torch.rand(F)
+    y[1] = torch.rand(F)
+    for t in range(T - 1):
+        y[:, t + 1] = 0.95 * y[:, t]
+    return y
+
+
+### END-TO-END #################################################################
+
+
+@pytest.fixture
+def y_data_large():
+    return _make_series(T=100, F=15)
+
+
+class MinimalModel(torch.nn.Module):
+    """
+    Minimal model that applies a linear transformation.
+    Used for end-to-end testing. Since the problem dynamics are linear, this
+    model should in principle learn the correct transformation.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.layers = torch.nn.Linear(15, 15, bias=False)
+
+    def forward(self, x):
+        return x + self.layers(x)
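+
+
+def test_make_series_follows_linear_dynamic():
+    # Sanity sketch: the synthetic data generated above obeys
+    # y[t + 1] = 0.95 * y[t] exactly, which the models below try to learn.
+    y = _make_series(T=5, F=2)
+    assert torch.allclose(y[:, 1:], 0.95 * y[:, :-1])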
+
+
+def test_end_to_end(y_data_large):
+    """
+    End-to-end test with MinimalModel. This test performs a three-phase
+    training with increasing unroll lengths, showing how to use the
+    AutoregressiveSolver with curriculum learning.
+    """
+
+    # AbstractProblem with empty conditions, to be filled later
+    class Problem(AbstractProblem):
+        output_variables = None
+        input_variables = None
+        conditions = {}
+
+    problem = Problem()
+
+    solver = AutoregressiveSolver(
+        problem=problem,
+        model=MinimalModel(),
+        optimizer=TorchOptimizer(torch.optim.AdamW, lr=0.015),
+    )
+
+    # PHASE 1: train with 'short' condition only
+    y_short = AutoregressiveSolver.unroll(
+        y_data_large, unroll_length=4, n_unrolls=20, randomize=False
+    )
+    problem.conditions["short"] = DataCondition(input=y_short)
+    problem.conditions["short"].settings = {"eps": 0.1}
+    trainer1 = Trainer(
+        solver, max_epochs=300, accelerator="cpu", enable_model_summary=False
+    )
+    trainer1.train()
+
+    # PHASE 2: train with 'medium' condition only
+    y_medium = AutoregressiveSolver.unroll(
+        y_data_large, unroll_length=10, n_unrolls=15, randomize=False
+    )
+    problem.conditions.clear()
+    problem.conditions["medium"] = DataCondition(input=y_medium)
+    problem.conditions["medium"].settings = {"eps": 0.2}
+    trainer2 = Trainer(
+        solver, max_epochs=1500, accelerator="cpu", enable_model_summary=False
+    )
+    trainer2.train()
+
+    # PHASE 3: train with 'long' condition only
+    y_long = AutoregressiveSolver.unroll(
+        y_data_large, unroll_length=20, n_unrolls=10, randomize=False
+    )
+    problem.conditions.clear()
+    problem.conditions["long"] = DataCondition(input=y_long)
+    problem.conditions["long"].settings = {"eps": 0.25}
+    trainer3 = Trainer(
+        solver, max_epochs=4000, accelerator="cpu", enable_model_summary=False
+    )
+    trainer3.train()
+
+    test_start_idx = 50
+    num_predictions = 49
+    start_state = y_data_large[:, test_start_idx, :].unsqueeze(1)
+    ground_truth = y_data_large[
+        :, test_start_idx : test_start_idx + num_predictions + 1, :
+    ]
+    prediction = solver.predict(start_state, n_steps=num_predictions)
+
+    # prediction has shape [B, 1, num_predictions + 1, F]
+    assert prediction.squeeze(1).shape == ground_truth.shape
+    total_mse = torch.nn.functional.mse_loss(
+        prediction.squeeze(1)[:, 1:, :], ground_truth[:, 1:, :]
+    )
+    assert total_mse < 1e-5
+
+
+### UNIT TESTS #################################################################
+
+NUM_TIMESTEPS = 10
+NUM_FEATURES = 3
+
+
+@pytest.fixture
+def y_data():
+    return _make_series(T=10, F=3)
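+
+
+def test_unroll_requires_three_dims():
+    # A small sketch of the input validation path: unroll rejects tensors
+    # with fewer than 3 dimensions.
+    with pytest.raises(ValueError):
+        AutoregressiveSolver.unroll(torch.rand(10, 3), unroll_length=2)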
+
+
+class ExactModel(torch.nn.Module):
+    """
+    This model implements the exact transformation rule
+    y[t + 1] = 0.95 * y[t], so the expected loss is zero.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.dummy_param = torch.nn.Parameter(torch.zeros(1))
+
+    def forward(self, x):
+        next_state = 0.95 * x
+        return next_state + 0.0 * self.dummy_param
+
+
+def test_unroll_shape_and_content(y_data):
+    B, T, F = y_data.shape
+
+    w = AutoregressiveSolver.unroll(
+        y_data, unroll_length=4, n_unrolls=2, randomize=False
+    )
+    # shape: (trajectories, windows, unroll_length, F)
+    assert w.shape == (B, 2, 4, F)
+
+    # first window (both series)
+    assert torch.allclose(w[0, 0], y_data[0, 0:4, :])
+    assert torch.allclose(w[1, 0], y_data[1, 0:4, :])
+
+    # second window (both series)
+    assert torch.allclose(w[0, 1], y_data[0, 1:5, :])
+    assert torch.allclose(w[1, 1], y_data[1, 1:5, :])
+
+
+def test_get_start_idx_edge_cases(y_data):
+    n_steps = y_data.shape[1]
+    idx = AutoregressiveSolver._get_start_idx(
+        n_steps, unroll_length=3, n_unrolls=None, randomize=False
+    )
+    # T=10, unroll_length=3 => last start index is 7 => 0..7
+    assert torch.equal(idx, torch.arange(8))
+
+    # An unroll length exceeding the number of time steps is rejected
+    with pytest.raises(ValueError):
+        AutoregressiveSolver._get_start_idx(
+            n_steps,
+            unroll_length=NUM_TIMESTEPS + 5,
+            n_unrolls=None,
+            randomize=False,
+        )
+
+
+def test_exact_model(y_data):
+    windows = AutoregressiveSolver.unroll(
+        y_data, unroll_length=5, n_unrolls=4, randomize=False
+    )
+
+    class Problem(AbstractProblem):
+        output_variables = None
+        input_variables = None
+        conditions = {
+            "data_condition": DataCondition(input=windows),
+        }
+
+    solver = AutoregressiveSolver(
+        problem=Problem(),
+        model=ExactModel(),
+        optimizer=TorchOptimizer(torch.optim.AdamW, lr=0.01),
+    )
+
+    loss = solver.loss_autoregressive(
+        windows,
+        eps=None,
+        aggregation_strategy=torch.sum,
+        condition_name="data_condition",
+    )
+    assert torch.isclose(loss, torch.tensor(0.0), atol=1e-6)
+
+
+def test_predict_matches_ground_truth(y_data):
+    class Problem(AbstractProblem):
+        output_variables = None
+        input_variables = None
+        conditions = {"data": DataCondition(input=y_data)}
+
+    solver = AutoregressiveSolver(problem=Problem(), model=ExactModel())
+
+    pred = solver.predict(
+        y_data[:, 0, :].unsqueeze(1), n_steps=NUM_TIMESTEPS - 1
+    )
+    # pred has shape [B, 1, T, F]
+    assert pred.squeeze(1).shape == y_data.shape
+    assert torch.allclose(pred.squeeze(1), y_data, atol=1e-6)
+
+
+def test_adaptive_weights_are_finite_and_decreasing(y_data):
+    class Problem(AbstractProblem):
+        output_variables = None
+        input_variables = None
+        conditions = {"data": DataCondition(input=y_data)}
+
+    solver = AutoregressiveSolver(problem=Problem(), model=ExactModel())
+
+    step_losses = torch.tensor([1.0, 2.0, 3.0])
+
+    w1 = solver._compute_adaptive_weights(step_losses, eps=1.0)
+    assert torch.isfinite(w1).all()
+    assert torch.all(w1[1:] <= w1[:-1])
+
+    w2 = solver._compute_adaptive_weights(step_losses, eps=None)
+    assert torch.equal(w2, torch.ones_like(step_losses))
+
+    w3 = solver._get_weights("data", step_losses, eps=1.0)
+    assert torch.isfinite(w3).all()
+    assert torch.allclose(w3, w1)
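+
+
+def test_default_pre_and_postprocess_are_identity(y_data):
+    # A small sketch of the default hooks: the base solver's preprocess_step
+    # and postprocess_step are identity maps that subclasses may override.
+    class Problem(AbstractProblem):
+        output_variables = None
+        input_variables = None
+        conditions = {"data": DataCondition(input=y_data)}
+
+    solver = AutoregressiveSolver(problem=Problem(), model=ExactModel())
+
+    state = torch.rand(NUM_TIMESERIES, 1, NUM_FEATURES)
+    assert torch.equal(solver.preprocess_step(state), state)
+    assert torch.equal(solver.postprocess_step(state), state)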
+
+
+def test_trainer_integration_one_epoch(y_data):
+    windows = AutoregressiveSolver.unroll(
+        y_data, unroll_length=5, n_unrolls=None, randomize=False
+    )
+
+    class Problem(AbstractProblem):
+        output_variables = None
+        input_variables = None
+        conditions = {"data": DataCondition(input=windows)}
+
+    solver = AutoregressiveSolver(
+        problem=Problem(),
+        model=ExactModel(),
+        optimizer=TorchOptimizer(torch.optim.AdamW, lr=1e-2),
+    )
+
+    trainer = Trainer(
+        solver=solver,
+        max_epochs=1,
+        accelerator="cpu",
+    )
+    trainer.train()
+
+    with torch.no_grad():
+        loss = solver.loss_autoregressive(
+            windows[:, :4],
+            eps=None,
+            aggregation_strategy=torch.sum,
+            condition_name="data",
+        )
+    assert torch.isfinite(loss)
+
+
+def test_weight_cache_resets_on_epoch_start(y_data):
+    class Problem(AbstractProblem):
+        output_variables = None
+        input_variables = None
+        conditions = {"data": DataCondition(input=y_data)}
+
+    solver = AutoregressiveSolver(
+        problem=Problem(),
+        model=ExactModel(),
+        reset_weights_at_epoch_start=True,
+    )
+
+    step_losses = torch.tensor([1.0, 2.0, 3.0])
+
+    _ = solver._get_weights("data", step_losses, eps=1.0)
+    assert "data" in solver._running_avg
+    assert "data" in solver._step_count
+
+    solver.on_train_epoch_start()
+
+    assert solver._running_avg == {}
+    assert solver._step_count == {}
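+
+
+def test_unroll_respects_n_unrolls(y_data):
+    # A small sketch of the n_unrolls cap: with randomize=True the starting
+    # indices are permuted, but at most n_unrolls windows survive.
+    w = AutoregressiveSolver.unroll(y_data, unroll_length=4, n_unrolls=3)
+    assert w.shape == (NUM_TIMESERIES, 3, 4, NUM_FEATURES)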