Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions pina/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,10 @@
"Condition",
"PinaDataModule",
"Graph",
"SolverInterface",
"MultiSolverInterface",
]

from pina._src.core.label_tensor import LabelTensor
from pina._src.core.graph import Graph
from pina._src.solver.solver import SolverInterface, MultiSolverInterface
from pina._src.core.trainer import Trainer
from pina._src.condition.condition import Condition
from pina._src.data.data_module import PinaDataModule
6 changes: 3 additions & 3 deletions pina/_src/callback/refinement/r3_refinement.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
)
from pina._src.core.label_tensor import LabelTensor
from pina._src.core.utils import check_consistency
from pina._src.loss.loss_interface import LossInterface
from pina._src.loss.loss_interface import DualLossInterface


class R3Refinement(RefinementInterface):
Expand Down Expand Up @@ -44,7 +44,7 @@ def __init__(
:param int sample_every: The sampling frequency.
:param loss: The loss function to compute the residuals.
Default is :class:`~torch.nn.L1Loss`.
:type loss: LossInterface | :class:`~torch.nn.modules.loss._Loss`
:type loss: DualLossInterface | :class:`~torch.nn.modules.loss._Loss`
:param condition_to_update: The conditions to update during the
refinement process. If None, all conditions will be updated.
Default is None.
Expand All @@ -59,7 +59,7 @@ def __init__(
# Check consistency
check_consistency(
residual_loss,
(LossInterface, torch.nn.modules.loss._Loss),
(DualLossInterface, torch.nn.modules.loss._Loss),
subclass=True,
)

Expand Down
4 changes: 1 addition & 3 deletions pina/_src/callback/refinement/refinement_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@
from abc import ABCMeta, abstractmethod
from lightning.pytorch import Callback
from pina._src.core.utils import check_consistency
from pina._src.solver.physics_informed_solver.pinn_interface import (
PINNInterface,
)
from pina._src.solver.pinn import PINN as PINNInterface
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please, import PINN without aliasing.



class RefinementInterface(Callback, metaclass=ABCMeta):
Expand Down
97 changes: 58 additions & 39 deletions pina/_src/condition/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from pina._src.condition.input_equation_condition import InputEquationCondition
from pina._src.condition.input_target_condition import InputTargetCondition
from pina._src.condition.time_series_condition import TimeSeriesCondition
from pina._src.condition.data_condition import DataCondition
from pina._src.condition.domain_equation_condition import (
DomainEquationCondition,
Expand Down Expand Up @@ -45,20 +46,29 @@ class Condition:
represents a general physics-informed condition defined by ``input``
points and an ``equation``. The model learns to minimize the equation
residual through evaluations performed at the provided ``input``.
Supported data types for the ``input`` include
:class:`~pina.label_tensor.LabelTensor` or :class:`~pina.graph.Graph`.
Supported data types for the ``input`` include :class:`~pina.graph.Graph`
or :class:`~pina.label_tensor.LabelTensor`. The class automatically
selects the appropriate implementation based on the types of the
``input``.

- :class:`~pina.condition.time_series_condition.TimeSeriesCondition`:
represents a condition designed for time series data, where the model is
trained to capture temporal dependencies and dynamics. It is defined by an
``input`` tensor of shape ``[trajectories, time_steps, *features]``
containing time series data. Supported data types for the ``input``
include class:`~pina.label_tensor.LabelTensor` or :class:`torch.Tensor`.
The class automatically selects the appropriate implementation based on
the types of the ``input``.
the type of the ``input``.

- :class:`~pina.condition.data_condition.DataCondition`: represents an
unsupervised, data-driven condition defined by the ``input`` only.
The model is trained using a custom unsupervised loss determined by the
chosen :class:`~pina.solver.solver.SolverInterface`, while leveraging the
provided data during training. Optional ``conditional_variables`` can be
specified when the model depends on additional parameters.
Supported data types include :class:`torch.Tensor`,
:class:`~pina.label_tensor.LabelTensor`, :class:`~pina.graph.Graph`, or
:class:`~torch_geometric.data.Data`. The class automatically selects the
Supported data types include :class:`~pina.label_tensor.LabelTensor`,
:class:`torch.Tensor`, :class:`~torch_geometric.data.Data`, or
:class:`~pina.graph.Graph`. The class automatically selects the
appropriate implementation based on the type of the ``input``.

.. note::
Expand All @@ -80,20 +90,32 @@ class Condition:
>>> # Example of InputEquationCondition signature
>>> condition = Condition(input=input, equation=equation)

>>> # Example of TimeSeriesCondition signature
>>> condition = Condition(
... input=input, n_windows=n_windows, unroll_length=unroll_length
... )

>>> # Example of DataCondition signature
>>> condition = Condition(input=data, conditional_variables=cond_vars)
"""

# Combine all possible keyword arguments from the different Condition types
available_kwargs = list(
set(
InputTargetCondition.__fields__
+ InputEquationCondition.__fields__
+ DomainEquationCondition.__fields__
+ DataCondition.__fields__
)
# Internal specifications for condition types, used for dispatching
# Each tuple contains: (condition class, required kwargs, optional kwargs)
_SPECS = (
(InputTargetCondition, {"input", "target"}, set()),
(InputEquationCondition, {"input", "equation"}, set()),
(DomainEquationCondition, {"domain", "equation"}, set()),
(DataCondition, {"input"}, {"conditional_variables"}),
(
TimeSeriesCondition,
{"input", "n_windows", "unroll_length"},
{"randomize"},
),
)

# Compute the set of all available keyword arguments (optional + required)
available_kwargs = sorted(set().union(*(rq | op for _, rq, op in _SPECS)))

def __new__(cls, *args, **kwargs):
"""
Instantiate the appropriate :class:`Condition` object based on the
Expand All @@ -103,38 +125,35 @@ def __new__(cls, *args, **kwargs):
:param dict kwargs: The keyword arguments corresponding to the
parameters of the specific :class:`Condition` type to instantiate.
:raises ValueError: If unexpected positional arguments are provided.
:raises ValueError: If the keyword arguments are invalid.
:raises ValueError: If the keyword arguments do not match any valid
signature for the available condition types.
:return: The appropriate :class:`Condition` object.
:rtype: ConditionInterface
"""
# Check keyword arguments
if len(args) != 0:
# Ensure no positional arguments are provided
if args:
raise ValueError(
"Condition takes only the following keyword "
f"arguments: {Condition.available_kwargs}."
"Condition takes only keyword arguments. "
f"Available arguments are: {cls.available_kwargs}."
)

# Class specialization based on keyword arguments
sorted_keys = sorted(kwargs.keys())

# Input - Target Condition
if sorted_keys == sorted(InputTargetCondition.__fields__):
return InputTargetCondition(**kwargs)
# Iterate through the specifications to find a matching condition type
for condition_cls, required, optional in cls._SPECS:

# Input - Equation Condition
if sorted_keys == sorted(InputEquationCondition.__fields__):
return InputEquationCondition(**kwargs)
# Find allowed keys for condition type
allowed = required | optional

# Domain - Equation Condition
if sorted_keys == sorted(DomainEquationCondition.__fields__):
return DomainEquationCondition(**kwargs)
# Check if the provided keys match the required and optional keys
if required <= set(kwargs) <= allowed:
return condition_cls(**kwargs)

# Data Condition
if (
sorted_keys == sorted(DataCondition.__fields__)
or sorted_keys[0] == DataCondition.__fields__[0]
):
return DataCondition(**kwargs)
# If no valid signature is found, prepare a list of valid signatures
valid_signatures = [
sorted(required | optional) for _, required, optional in cls._SPECS
]

# Invalid keyword arguments
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")
# If no valid signature is found, raise an error
raise ValueError(
f"Invalid keyword arguments {sorted(set(kwargs))}. "
f"Valid signatures are: {valid_signatures}."
)
25 changes: 25 additions & 0 deletions pina/_src/condition/condition_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,31 @@ def create_dataloader(
:rtype: torch.utils.data.DataLoader
"""

@abstractmethod
def evaluate(self, batch, solver, loss):
"""
Evaluate the residual of the condition on the given batch using the
solver.

This method computes the non-aggregated, element-wise residual of the
condition. A forward pass of the solver's model is performed on the
input samples, and the condition residual is evaluated accordingly.

The returned tensor is not reduced, preserving the per-sample residual
values.

:param dict batch: The batch containing the data required by the
condition evaluation.
:param SolverInterface solver: The solver used to perform the forward
pass and compute the residual. The solver provides access to the
model and its parameters, which may be necessary for evaluating the
condition residual.
:param torch.nn.Module loss: The non-aggregating loss function used to
compare the condition residual against its reference value.
:return: The non-aggregated residual tensor.
:rtype: torch.Tensor | LabelTensor
"""

@abstractmethod
def switch_dataloader_fn(self, create_dataloader_fn):
"""
Expand Down
26 changes: 26 additions & 0 deletions pina/_src/condition/data_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,32 @@ def store_data(self, **kwargs):

return _DataManager(**data_dict)

def evaluate(self, batch, solver, loss):
"""
Evaluate the residual of the condition on the given batch using the
solver.

This method computes the non-aggregated, element-wise residual of the
condition. A forward pass of the solver's model is performed on the
input samples, and the condition residual is evaluated accordingly.

The returned tensor is not reduced, preserving the per-sample residual
values.

:param dict batch: The batch containing the data required by the
condition evaluation.
:param SolverInterface solver: The solver used to perform the forward
pass and compute the residual. The solver provides access to the
model and its parameters, which may be necessary for evaluating the
condition residual.
:param torch.nn.Module loss: The non-aggregating loss function used to
compare the condition residual against its reference value.
:return: The non-aggregated residual tensor.
:rtype: torch.Tensor | LabelTensor
"""
output_ = solver.forward(batch["input"])
return loss(output_, torch.zeros_like(output_))

@property
def conditional_variables(self):
"""
Expand Down
32 changes: 32 additions & 0 deletions pina/_src/condition/domain_equation_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,38 @@ def store_data(self, **kwargs):
setattr(self, "domain", kwargs.get("domain"))
setattr(self, "equation", kwargs.get("equation"))

def evaluate(self, batch, solver, loss):
"""
Evaluate the residual of the condition on the given batch using the
solver.

This method computes the non-aggregated, element-wise residual of the
condition. A forward pass of the solver's model is performed on the
input samples, and the condition residual is evaluated accordingly.

The returned tensor is not reduced, preserving the per-sample residual
values.

:param dict batch: The batch containing the data required by the
condition evaluation.
:param SolverInterface solver: The solver used to perform the forward
pass and compute the residual. The solver provides access to the
model and its parameters, which may be necessary for evaluating the
condition residual.
:param torch.nn.Module loss: The non-aggregating loss function used to
compare the condition residual against its reference value.
:raises NotImplementedError: Always raised since any domain-equation
condition is transformed into an input-equation condition before
evaluation, and the residual is computed using the input-equation
condition's evaluation method.
"""
raise NotImplementedError(
"Domain-equation conditions are transformed into input-equation "
"conditions before evaluation, and the residual is computed using "
"the input-equation condition's evaluation method. Therefore, the "
"evaluate method is not implemented for domain-equation conditions."
)

@property
def equation(self):
"""
Expand Down
27 changes: 27 additions & 0 deletions pina/_src/condition/input_equation_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,30 @@ def equation(self, value):
# Check consistency
check_consistency(value, self._avail_equation_cls)
self._equation = value

def evaluate(self, batch, solver, _):
"""
Evaluate the residual of the condition on the given batch using the
solver.

This method computes the non-aggregated, element-wise residual of the
condition. A forward pass of the solver's model is performed on the
input samples, and the condition residual is evaluated accordingly.

The returned tensor is not reduced, preserving the per-sample residual
values.

:param dict batch: The batch containing the data required by the
condition evaluation.
:param SolverInterface solver: The solver used to perform the forward
pass and compute the residual. The solver provides access to the
model and its parameters, which may be necessary for evaluating the
condition residual.
:param _: Placeholder argument (not used).
:return: The non-aggregated residual tensor.
:rtype: LabelTensor
"""
samples = batch["input"].requires_grad_(True)
return self.equation.residual(
samples, solver.forward(samples), solver._params
)
25 changes: 25 additions & 0 deletions pina/_src/condition/input_target_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,31 @@ def store_data(self, **kwargs):
"""
return _DataManager(**kwargs)

def evaluate(self, batch, solver, loss):
"""
Evaluate the residual of the condition on the given batch using the
solver.

This method computes the non-aggregated, element-wise residual of the
condition. A forward pass of the solver's model is performed on the
input samples, and the condition residual is evaluated accordingly.

The returned tensor is not reduced, preserving the per-sample residual
values.

:param dict batch: The batch containing the data required by the
condition evaluation.
:param SolverInterface solver: The solver used to perform the forward
pass and compute the residual. The solver provides access to the
model and its parameters, which may be necessary for evaluating the
condition residual.
:param torch.nn.Module loss: The non-aggregating loss function used to
compare the condition residual against its reference value.
:return: The non-aggregated residual tensor.
:rtype: torch.Tensor | LabelTensor
"""
return loss(solver.forward(batch["input"]), batch["target"])

@property
def input(self):
"""
Expand Down
Loading
Loading