Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 9 additions & 15 deletions isaaclab_arena/assets/dummy_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch

from isaaclab_arena.relations.relations import AtPosition, Relation, RelationBase
from isaaclab_arena.utils.bounding_box import AxisAlignedBoundingBox
from isaaclab_arena.utils.bounding_box import AxisAlignedBoundingBox, quaternion_to_90_deg_z_quarters
from isaaclab_arena.utils.pose import Pose


Expand Down Expand Up @@ -43,20 +43,14 @@ def get_bounding_box(self) -> AxisAlignedBoundingBox:
return self.bounding_box

def get_world_bounding_box(self) -> AxisAlignedBoundingBox:
"""Get bounding box in world coordinates (local bbox + position offset)."""
pos = self.initial_pose.position_xyz if self.initial_pose else (0, 0, 0)
return AxisAlignedBoundingBox(
min_point=(
self.bounding_box.min_point[0] + pos[0],
self.bounding_box.min_point[1] + pos[1],
self.bounding_box.min_point[2] + pos[2],
),
max_point=(
self.bounding_box.max_point[0] + pos[0],
self.bounding_box.max_point[1] + pos[1],
self.bounding_box.max_point[2] + pos[2],
),
)
"""Get bounding box in world coordinates (local bbox rotated and translated).

Only 90° rotations around Z axis are supported.
"""
if self.initial_pose is None:
return self.bounding_box
quarters = quaternion_to_90_deg_z_quarters(self.initial_pose.rotation_wxyz)
return self.bounding_box.rotated_90_around_z(quarters).translated(self.initial_pose.position_xyz)

def get_corners_aabb(self, pos: torch.Tensor) -> torch.Tensor:
return self.bounding_box.get_corners_at(pos)
Expand Down
16 changes: 12 additions & 4 deletions isaaclab_arena/assets/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from isaaclab_arena.assets.object_utils import detect_object_type
from isaaclab_arena.relations.relations import RelationBase
from isaaclab_arena.terms.events import set_object_pose
from isaaclab_arena.utils.bounding_box import AxisAlignedBoundingBox
from isaaclab_arena.utils.bounding_box import AxisAlignedBoundingBox, quaternion_to_90_deg_z_quarters
from isaaclab_arena.utils.pose import Pose, PoseRange
from isaaclab_arena.utils.usd.rigid_bodies import find_shallowest_rigid_body
from isaaclab_arena.utils.usd_helpers import compute_local_bounding_box_from_usd, has_light, open_stage
Expand Down Expand Up @@ -68,10 +68,18 @@ def get_bounding_box(self) -> AxisAlignedBoundingBox:
return self.bounding_box

def get_world_bounding_box(self) -> AxisAlignedBoundingBox:
"""Get bounding box in world coordinates (local bbox + position offset)."""
"""Get bounding box in world coordinates (local bbox rotated and translated).

Only 90° rotations around Z axis are supported. An assertion error is raised
for any other rotation. If initial_pose is a PoseRange (not a fixed Pose),
returns the local bounding box without transformation.
"""
local_bbox = self.get_bounding_box()
pos = self.initial_pose.position_xyz if self.initial_pose else (0, 0, 0)
return local_bbox.translated(pos)
if self.initial_pose is None:
return local_bbox
assert isinstance(self.initial_pose, Pose), "Only Pose is supported for world bounding box"
quarters = quaternion_to_90_deg_z_quarters(self.initial_pose.rotation_wxyz)
return local_bbox.rotated_90_around_z(quarters).translated(self.initial_pose.position_xyz)

def get_corners(self, pos: torch.Tensor) -> torch.Tensor:
assert self.usd_path is not None
Expand Down
12 changes: 9 additions & 3 deletions isaaclab_arena/assets/object_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from isaaclab_arena.assets.asset import Asset
from isaaclab_arena.assets.object_base import ObjectBase, ObjectType
from isaaclab_arena.relations.relations import IsAnchor, RelationBase
from isaaclab_arena.utils.bounding_box import AxisAlignedBoundingBox
from isaaclab_arena.utils.bounding_box import AxisAlignedBoundingBox, quaternion_to_90_deg_z_quarters
from isaaclab_arena.utils.pose import Pose
from isaaclab_arena.utils.usd_helpers import compute_local_bounding_box_from_prim, open_stage
from isaaclab_arena.utils.usd_pose_helpers import get_prim_pose_in_default_prim_frame
Expand Down Expand Up @@ -74,8 +74,14 @@ def get_bounding_box(self) -> AxisAlignedBoundingBox:
return self._bounding_box

def get_world_bounding_box(self) -> AxisAlignedBoundingBox:
"""Get bounding box in world coordinates (local bbox + world position)."""
return self.get_bounding_box().translated(self.get_initial_pose().position_xyz)
"""Get bounding box in world coordinates (local bbox rotated and translated).

Only 90° rotations around Z axis are supported for AxisAlignedBoundingBox.
An assertion error is raised for any other rotation.
"""
pose = self.get_initial_pose()
quarters = quaternion_to_90_deg_z_quarters(pose.rotation_wxyz)
return self.get_bounding_box().rotated_90_around_z(quarters).translated(pose.position_xyz)

def get_contact_sensor_cfg(self, contact_against_prim_paths: list[str] | None = None) -> ContactSensorCfg:
# NOTE(alexmillane): Right now this requires that the object
Expand Down
79 changes: 33 additions & 46 deletions isaaclab_arena/relations/relation_loss_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ class Direction(IntEnum):

@dataclass(frozen=True)
class SideConfig:
"""Configuration for computing NextTo loss for a given side.
"""Configuration for computing NextTo loss for a given axis direction.

Attributes:
primary_axis: Axis along which child is placed (X or Y).
direction: POSITIVE if child should be in positive direction from parent (RIGHT, BACK),
NEGATIVE if child should be in negative direction (LEFT, FRONT).
direction: POSITIVE if child should be in positive direction from parent,
NEGATIVE if child should be in negative direction.
"""

primary_axis: Axis
Expand All @@ -57,10 +57,10 @@ def band_axis(self) -> Axis:


SIDE_CONFIGS: dict[Side, SideConfig] = {
Side.RIGHT: SideConfig(primary_axis=Axis.X, direction=Direction.POSITIVE),
Side.LEFT: SideConfig(primary_axis=Axis.X, direction=Direction.NEGATIVE),
Side.BACK: SideConfig(primary_axis=Axis.Y, direction=Direction.POSITIVE),
Side.FRONT: SideConfig(primary_axis=Axis.Y, direction=Direction.NEGATIVE),
Side.POSITIVE_X: SideConfig(primary_axis=Axis.X, direction=Direction.POSITIVE),
Side.NEGATIVE_X: SideConfig(primary_axis=Axis.X, direction=Direction.NEGATIVE),
Side.POSITIVE_Y: SideConfig(primary_axis=Axis.Y, direction=Direction.POSITIVE),
Side.NEGATIVE_Y: SideConfig(primary_axis=Axis.Y, direction=Direction.NEGATIVE),
}


Expand Down Expand Up @@ -95,18 +95,16 @@ def compute_loss(
self,
relation: "Relation",
child_pos: torch.Tensor,
parent_pos: torch.Tensor,
child_bbox: AxisAlignedBoundingBox,
parent_bbox: AxisAlignedBoundingBox,
parent_world_bbox: AxisAlignedBoundingBox,
) -> torch.Tensor:
"""Compute the loss for a relation constraint.

Args:
relation: The relation object containing relationship metadata.
child_pos: Child object position tensor (x, y, z) in world coords.
parent_pos: Parent object position tensor (x, y, z) in world coords.
child_bbox: Child object local bounding box (extents relative to origin).
parent_bbox: Parent object local bounding box (extents relative to origin).
parent_world_bbox: Parent bounding box in world coordinates.

Returns:
Scalar loss tensor representing the constraint violation.
Expand Down Expand Up @@ -137,21 +135,18 @@ def compute_loss(
self,
relation: "NextTo",
child_pos: torch.Tensor,
parent_pos: torch.Tensor,
child_bbox: AxisAlignedBoundingBox,
parent_bbox: AxisAlignedBoundingBox,
parent_world_bbox: AxisAlignedBoundingBox,
) -> torch.Tensor:
"""Compute loss for NextTo relation.

Uses world-space extents (position + bbox.min/max) for origin-agnostic computation.
Supports all four sides: LEFT, RIGHT, FRONT, BACK.

Args:
relation: NextTo relation with side and distance attributes.
child_pos: Child object position tensor (x, y, z) in world coords.
parent_pos: Parent object position tensor (x, y, z) in world coords.
child_bbox: Child object local bounding box.
parent_bbox: Parent object local bounding box.
parent_world_bbox: Parent bounding box in world coordinates.

Returns:
Weighted loss tensor.
Expand All @@ -160,13 +155,13 @@ def compute_loss(
distance = relation.distance_m
assert distance >= 0.0, f"NextTo distance must be non-negative, got {distance}"

# Select parent edge and child offset based on direction
# Parent world extents from the world bounding box
if cfg.direction == Direction.POSITIVE:
parent_edge = parent_pos[cfg.primary_axis] + parent_bbox.max_point[cfg.primary_axis]
parent_edge = parent_world_bbox.max_point[cfg.primary_axis]
child_offset = child_bbox.min_point[cfg.primary_axis]
penalty_side = "less"
else:
parent_edge = parent_pos[cfg.primary_axis] + parent_bbox.min_point[cfg.primary_axis]
parent_edge = parent_world_bbox.min_point[cfg.primary_axis]
child_offset = child_bbox.max_point[cfg.primary_axis]
penalty_side = "greater"

Expand All @@ -179,9 +174,8 @@ def compute_loss(
)

# 2. Band loss: child's footprint must be within parent's extent on perpendicular axis
# Compute valid position range such that child's entire footprint stays within parent
parent_band_min = parent_pos[cfg.band_axis] + parent_bbox.min_point[cfg.band_axis]
parent_band_max = parent_pos[cfg.band_axis] + parent_bbox.max_point[cfg.band_axis]
parent_band_min = parent_world_bbox.min_point[cfg.band_axis]
parent_band_max = parent_world_bbox.max_point[cfg.band_axis]
valid_band_min = parent_band_min - child_bbox.min_point[cfg.band_axis]
valid_band_max = parent_band_max - child_bbox.max_point[cfg.band_axis]
band_loss = linear_band_loss(
Expand All @@ -202,17 +196,17 @@ def compute_loss(
band_axis_name = cfg.band_axis.name
print(
f" [NextTo] {relation.side.value}: child_{axis_name.lower()}="
f"{child_pos[cfg.primary_axis].item():.4f}, parent_edge={parent_edge.item():.4f},"
f"{child_pos[cfg.primary_axis].item():.4f}, parent_edge={parent_edge:.4f},"
f" loss={half_plane_loss.item():.6f}"
)
print(
f" [NextTo] {band_axis_name} band: child_{band_axis_name.lower()}="
f"{child_pos[cfg.band_axis].item():.4f}, valid_range=[{valid_band_min.item():.4f},"
f" {valid_band_max.item():.4f}], loss={band_loss.item():.6f}"
f"{child_pos[cfg.band_axis].item():.4f}, valid_range=[{valid_band_min:.4f},"
f" {valid_band_max:.4f}], loss={band_loss.item():.6f}"
)
print(
f" [NextTo] Distance: child_{axis_name.lower()}="
f"{child_pos[cfg.primary_axis].item():.4f}, target={target_pos.item():.4f},"
f"{child_pos[cfg.primary_axis].item():.4f}, target={target_pos:.4f},"
f" loss={distance_loss.item():.6f}"
)

Expand Down Expand Up @@ -243,30 +237,26 @@ def compute_loss(
self,
relation: "On",
child_pos: torch.Tensor,
parent_pos: torch.Tensor,
child_bbox: AxisAlignedBoundingBox,
parent_bbox: AxisAlignedBoundingBox,
parent_world_bbox: AxisAlignedBoundingBox,
) -> torch.Tensor:
"""Compute loss for On relation.

Uses world-space extents (position + bbox.min/max) for origin-agnostic computation.

Args:
relation: On relation with clearance_m attribute.
child_pos: Child object position tensor (x, y, z) in world coords.
parent_pos: Parent object position tensor (x, y, z) in world coords.
child_bbox: Child object local bounding box.
parent_bbox: Parent object local bounding box.
parent_world_bbox: Parent bounding box in world coordinates.

Returns:
Weighted loss tensor.
"""
# Compute parent world-space extents
parent_x_min = parent_pos[0] + parent_bbox.min_point[0]
parent_x_max = parent_pos[0] + parent_bbox.max_point[0]
parent_y_min = parent_pos[1] + parent_bbox.min_point[1]
parent_y_max = parent_pos[1] + parent_bbox.max_point[1]
parent_z_max = parent_pos[2] + parent_bbox.max_point[2] # Top surface
# Parent world-space extents from the world bounding box
parent_x_min = parent_world_bbox.min_point[0]
parent_x_max = parent_world_bbox.max_point[0]
parent_y_min = parent_world_bbox.min_point[1]
parent_y_max = parent_world_bbox.max_point[1]
parent_z_max = parent_world_bbox.max_point[2] # Top surface

# Compute valid position ranges such that child's entire footprint is within parent
# Child left edge = child_pos[0] + child_bbox.min_point[0], must be >= parent_x_min
Expand Down Expand Up @@ -298,17 +288,14 @@ def compute_loss(

if self.debug:
print(
f" [On] X: child_pos={child_pos[0].item():.4f}, valid_range=[{valid_x_min.item():.4f},"
f" {valid_x_max.item():.4f}], loss={x_band_loss.item():.6f}"
)
print(
f" [On] Y: child_pos={child_pos[1].item():.4f}, valid_range=[{valid_y_min.item():.4f},"
f" {valid_y_max.item():.4f}], loss={y_band_loss.item():.6f}"
f" [On] X: child_pos={child_pos[0].item():.4f}, valid_range=[{valid_x_min:.4f},"
f" {valid_x_max:.4f}], loss={x_band_loss.item():.6f}"
)
print(
f" [On] Z: child_pos={child_pos[2].item():.4f}, target={target_z.item():.4f},"
f" loss={z_loss.item():.6f}"
f" [On] Y: child_pos={child_pos[1].item():.4f}, valid_range=[{valid_y_min:.4f},"
f" {valid_y_max:.4f}], loss={y_band_loss.item():.6f}"
)
print(f" [On] Z: child_pos={child_pos[2].item():.4f}, target={target_z:.4f}, loss={z_loss.item():.6f}")

total_loss = x_band_loss + y_band_loss + z_loss
return relation.relation_loss_weight * total_loss
Expand Down
35 changes: 19 additions & 16 deletions isaaclab_arena/relations/relation_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,15 +89,24 @@ def _compute_total_loss(self, state: RelationSolverState, debug: bool = False) -
_print_unary_relation_debug(obj, relation, child_pos, loss)
# Handle binary relations (with parent) like On, NextTo
elif isinstance(relation, Relation):
parent_pos = state.get_position(relation.parent)
# Build parent world bbox: anchors have a known fixed pose,
# optimizable parents use the current solver position + local bbox.
parent = relation.parent
if parent in state.anchor_objects:
parent_world_bbox = parent.get_world_bounding_box()
else:
parent_pos = state.get_position(parent)
parent_world_bbox = parent.get_bounding_box().translated(
(float(parent_pos[0]), float(parent_pos[1]), float(parent_pos[2]))
)
loss = strategy.compute_loss(
relation=relation,
child_pos=child_pos,
parent_pos=parent_pos,
child_bbox=obj.get_bounding_box(),
parent_bbox=relation.parent.get_bounding_box(),
parent_world_bbox=parent_world_bbox,
)
if debug:
parent_pos = state.get_position(parent)
_print_relation_debug(obj, relation, child_pos, parent_pos, loss)
else:
raise ValueError(f"Unknown relation type: {type(relation).__name__}")
Expand Down Expand Up @@ -225,31 +234,25 @@ def _print_relation_debug(
) -> None:
"""Print debug information for a single binary relation."""
child_bbox = obj.get_bounding_box()
parent_bbox = relation.parent.get_bounding_box()
parent_world_bbox = relation.parent.get_world_bounding_box()

print(f"\n=== {obj.name} -> {type(relation).__name__}({relation.parent.name}) ===")
print(f" Child pos: ({child_pos[0].item():.4f}, {child_pos[1].item():.4f}, {child_pos[2].item():.4f})")
print(f" Child bbox: min={child_bbox.min_point}, max={child_bbox.max_point}, size={child_bbox.size}")
print(f" Parent pos: ({parent_pos[0].item():.4f}, {parent_pos[1].item():.4f}, {parent_pos[2].item():.4f})")
print(f" Parent bbox: min={parent_bbox.min_point}, max={parent_bbox.max_point}, size={parent_bbox.size}")
print(
f" Parent world bbox: min={parent_world_bbox.min_point}, max={parent_world_bbox.max_point},"
f" size={parent_world_bbox.size}"
)

# Child world extents
child_x_range = (child_pos[0].item() + child_bbox.min_point[0], child_pos[0].item() + child_bbox.max_point[0])
child_y_range = (child_pos[1].item() + child_bbox.min_point[1], child_pos[1].item() + child_bbox.max_point[1])
# Parent world extents
parent_x_range = (
parent_pos[0].item() + parent_bbox.min_point[0],
parent_pos[0].item() + parent_bbox.max_point[0],
)
parent_y_range = (
parent_pos[1].item() + parent_bbox.min_point[1],
parent_pos[1].item() + parent_bbox.max_point[1],
)

print(f" Child world X: [{child_x_range[0]:.4f}, {child_x_range[1]:.4f}]")
print(f" Child world Y: [{child_y_range[0]:.4f}, {child_y_range[1]:.4f}]")
print(f" Parent world X: [{parent_x_range[0]:.4f}, {parent_x_range[1]:.4f}]")
print(f" Parent world Y: [{parent_y_range[0]:.4f}, {parent_y_range[1]:.4f}]")
print(f" Parent world X: [{parent_world_bbox.min_point[0]:.4f}, {parent_world_bbox.max_point[0]:.4f}]")
print(f" Parent world Y: [{parent_world_bbox.min_point[1]:.4f}, {parent_world_bbox.max_point[1]:.4f}]")
print(f" Loss: {loss.item():.6f}")


Expand Down
14 changes: 7 additions & 7 deletions isaaclab_arena/relations/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@


class Side(Enum):
"""Side of an object for spatial relationships."""
"""Axis direction for spatial relationships."""

FRONT = "front" # -Y
BACK = "back" # +Y
LEFT = "left" # -X
RIGHT = "right" # +X
POSITIVE_X = "positive_x" # +X
NEGATIVE_X = "negative_x" # -X
POSITIVE_Y = "positive_y" # +Y
NEGATIVE_Y = "negative_y" # -Y


class RelationBase:
Expand Down Expand Up @@ -65,14 +65,14 @@ def __init__(
parent: Object | ObjectReference,
relation_loss_weight: float = 1.0,
distance_m: float = 0.05,
side: Side = Side.RIGHT,
side: Side = Side.POSITIVE_X,
):
"""
Args:
parent: The parent asset that this object should be placed next to.
relation_loss_weight: Weight for the relationship loss function.
distance_m: Target distance from parent's boundary in meters (default: 5cm).
side: Which side to place object (default: Side.RIGHT).
side: Which axis direction to place object (default: Side.POSITIVE_X).
"""
super().__init__(parent, relation_loss_weight)
assert distance_m > 0.0, f"Distance must be positive, got {distance_m}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def _create_test_objects() -> tuple[DummyObject, DummyObject, DummyObject]:

box1.add_relation(On(desk, clearance_m=0.01))
box2.add_relation(On(desk, clearance_m=0.01))
box2.add_relation(NextTo(box1, side=Side.RIGHT, distance_m=0.05))
box2.add_relation(NextTo(box1, side=Side.POSITIVE_X, distance_m=0.05))

return desk, box1, box2

Expand Down
Loading
Loading