Merged
11 changes: 0 additions & 11 deletions .vscode/settings.json

This file was deleted.

58 changes: 51 additions & 7 deletions README.md
@@ -20,7 +20,7 @@
χ₀ addresses the systematic distributional shift among the human demonstration distribution ($P_\text{train}$), the inductive bias learned by the policy ($Q_\text{model}$), and the test-time execution distribution ($P_\text{test}$) through three technical modules:

- **[Model Arithmetic](#model-arithmetic)**: A weight-space merging strategy that combines models trained on different data subsets, efficiently capturing diverse knowledge without architectural complexity. **[Released]**
- **[Stage Advantage](#stage-advantage-coming-soon)**: A stage-aware advantage estimator that provides stable, dense progress signals for policy training. **[Coming Soon]**
- **[Stage Advantage](#stage-advantage)**: A stage-aware advantage estimator that provides stable, dense progress signals for policy training. **[Released]**
- **[Train-Deploy Alignment](#train-deploy-alignment-coming-soon)**: Bridges the distribution gap via spatio-temporal augmentation, heuristic DAgger corrections, and temporal chunk-wise smoothing. **[Coming Soon]**

χ₀ enables two sets of dual-arm robots to collaboratively orchestrate long-horizon garment manipulation — flattening, folding, and hanging — surpassing the state-of-the-art $\pi_{0.5}$ baseline by approximately 250% in success rate, with only **20 hours of data and 8 A100 GPUs**.
@@ -46,14 +46,15 @@ https://github.com/user-attachments/assets/3f5f0c48-ff3f-4b9b-985b-59ad0b2ea97c
- [Model Arithmetic](#model-arithmetic)
- [Workflow](#workflow)
- [Quick Start](#quick-start)
- [Stage Advantage (Coming Soon)](#stage-advantage-coming-soon)
- [Stage Advantage](#stage-advantage)
- [Train-Deploy Alignment (Coming Soon)](#train-deploy-alignment-coming-soon)
- [Citation](#licenseandcitation)
- [Troubleshooting](#troubleshooting)
- [Links and Community](#links-and-community)

## Update

- [Feb 14 2026] Release of the **Stage Advantage** module: advantage estimator training, evaluation, GT labeling, and AWBC training pipeline.
- [Feb 10 2026] Initial release of the **Model Arithmetic** module with support for both JAX and PyTorch checkpoints (not tested thoroughly).
- [Feb 10 2026] χ₀ paper released.

@@ -208,9 +209,9 @@ Checkpoints are written to the config’s checkpoint directory. You can then use

- [x] kai0 oracle: training and inference code with non-advantage data of three tasks
- [x] Model Arithmetic: code of different baselines for weight-space interpolation
- [ ] Stage Advantage: code, data (advantage labels), and checkpoints — **Feb 12**
- [ ] HuggingFace & ModelScope: upload Stage Advantage data and checkpoints — **Feb 12**
- [ ] Train-Deploy Alignment — **Feb 15**
- [x] Stage Advantage: code, data (advantage labels), and checkpoints
- [ ] HuggingFace & ModelScope: upload Stage Advantage data and checkpoints — **Feb 14**
- [ ] Train-Deploy Alignment — **Feb 14**

## Model Arithmetic

@@ -265,11 +266,54 @@ python model_arithmetic/arithmetic_torch.py \

For gradient-based optimization, dataset splitting, and all other methods, see the full documentation in [`model_arithmetic/README.md`](model_arithmetic/README.md).
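As a concrete illustration of weight-space merging, two compatible checkpoints can be combined by per-tensor linear interpolation of their state dicts. The helper below is a minimal hypothetical sketch, not the actual interface of `arithmetic_torch.py`:

```python
import torch

def merge_state_dicts(sd_a, sd_b, alpha=0.5):
    """Linearly interpolate two compatible state dicts: alpha*A + (1-alpha)*B."""
    merged = {}
    for key, w_a in sd_a.items():
        w_b = sd_b[key]
        if w_a.dtype.is_floating_point:
            merged[key] = alpha * w_a + (1.0 - alpha) * w_b
        else:
            # Keep integer buffers (e.g. step counters) from checkpoint A.
            merged[key] = w_a
    return merged

# Toy usage with two tiny "checkpoints"
sd_a = {"linear.weight": torch.zeros(2, 2)}
sd_b = {"linear.weight": torch.ones(2, 2)}
merged = merge_state_dicts(sd_a, sd_b, alpha=0.25)
# → every entry is 0.25 * 0 + 0.75 * 1 = 0.75
```

Real checkpoints additionally require matching architectures and key sets; the repository's scripts handle those details.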

## Stage Advantage (Coming Soon)
## Stage Advantage

Stage Advantage decomposes long-horizon tasks into semantic stages and provides stage-aware advantage signals for policy training. It addresses the numerical instability of prior non-stage approaches by computing advantage as progress differentials within each stage, yielding smoother and more stable supervision.

**This module is currently under refinement and will be released soon.**
The full pipeline has four stages:

```
Stage 0: GT Labeling → Stage 1: Train Advantage Estimator → Stage 2: Advantage Estimation → Stage 3: AWBC Training
```
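The "progress differentials within each stage" idea above can be sketched as follows. This is an illustrative reading, not the released implementation: the function name, and zeroing the differential at stage boundaries, are assumptions.

```python
import numpy as np

def stage_advantage(progress, stage_ids):
    """Advantage as the within-stage progress differential:
    A_t = p_{t+1} - p_t, computed only between frames in the same stage,
    with 0 at stage boundaries (no cross-stage differential)."""
    progress = np.asarray(progress, dtype=float)
    stage_ids = np.asarray(stage_ids)
    adv = np.zeros_like(progress)
    same_stage = stage_ids[1:] == stage_ids[:-1]
    adv[:-1] = np.where(same_stage, progress[1:] - progress[:-1], 0.0)
    return adv

# Toy episode: two stages, progress resets at the stage boundary
progress = [0.0, 0.5, 1.0, 0.0, 0.6]
stages   = [0,   0,   0,   1,   1]
adv = stage_advantage(progress, stages)
# adv ≈ [0.5, 0.5, 0.0, 0.6, 0.0]
```

Note how the reset from 1.0 to 0.0 between stages does not produce a large negative spike, which is the instability the stage-aware formulation avoids.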

### Quick Start

**Stage 0 — GT Data Labeling**: Compute advantage values and discretize into `task_index` labels.

```bash
cd stage_advantage/annotation
python gt_label.py <dataset_path> \
--threshold 30 --chunk-size 50 --discretion-type binary \
--advantage-source absolute_advantage
```

For batch labeling across multiple dataset variants, see `stage_advantage/annotation/gt_labeling.sh`.
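To illustrate what chunk-wise binary discretization might look like, here is a hypothetical sketch; the actual semantics of `--threshold` and `--chunk-size` in `gt_label.py` may differ:

```python
import numpy as np

def binarize_advantage(adv, threshold, chunk_size):
    """Discretize per-frame advantage into one binary label per chunk:
    a chunk is labeled 1 if its mean advantage exceeds the threshold."""
    adv = np.asarray(adv, dtype=float)
    n_chunks = int(np.ceil(len(adv) / chunk_size))
    labels = np.zeros(n_chunks, dtype=int)
    for i in range(n_chunks):
        chunk = adv[i * chunk_size:(i + 1) * chunk_size]
        labels[i] = int(chunk.mean() > threshold)
    return labels

# 100 frames: first chunk high advantage, second chunk low
adv = np.concatenate([np.full(50, 40.0), np.full(50, 10.0)])
labels = binarize_advantage(adv, threshold=30, chunk_size=50)
# → array([1, 0])
```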

**Stage 1 — Train Advantage Estimator**: Fine-tune a pi0-based model to predict advantage from observations.

```bash
uv run python scripts/train_pytorch.py ADVANTAGE_TORCH_KAI0_FLATTEN_FOLD --exp_name=run1 --save_interval 10000
```

For a ready-to-use script with environment setup (conda/venv activation, DDP configuration) and automatic log management, see `stage_advantage/annotation/train_estimator.sh`.

**Stage 2 — Advantage Estimation on New Data**: Use the trained estimator to label datasets with predicted advantage values.

```bash
uv run python stage_advantage/annotation/eval.py Flatten-Fold KAI0 /path/to/dataset
```

For a ready-to-use script with environment setup and status logging, see `stage_advantage/annotation/eval.sh`.

**Stage 3 — AWBC Training**: Train a policy with Advantage-Weighted Behavior Cloning.

```bash
XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run scripts/train.py pi05_flatten_fold_awbc --exp_name=run1
```

For a ready-to-use script with environment setup and automatic log management, see `stage_advantage/awbc/train_awbc.sh`.
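Conceptually, AWBC reweights the behavior-cloning loss by an exponential of the advantage, so high-advantage transitions dominate training. The snippet below is an illustrative sketch with assumed hyperparameters (`beta` temperature and weight clipping), not the released training objective:

```python
import numpy as np

def awbc_loss(pred_actions, target_actions, advantage, beta=1.0, w_max=10.0):
    """Advantage-Weighted Behavior Cloning: per-sample MSE scaled by
    exp(advantage / beta), clipped to keep weights bounded."""
    per_sample = np.mean((pred_actions - target_actions) ** 2, axis=-1)
    weights = np.minimum(np.exp(np.asarray(advantage) / beta), w_max)
    return float(np.mean(weights * per_sample))

# Toy batch: identical BC error, different advantages
pred = np.zeros((3, 2))
target = np.ones((3, 2))
adv = np.array([0.0, 1.0, -1.0])
loss = awbc_loss(pred, target, adv, beta=1.0)
```

With equal per-sample error, the A = 1.0 transition contributes e ≈ 2.72 times the weight of the A = 0.0 transition, which is the intended effect of the weighting.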

For the full pipeline details, configuration instructions, and all parameters, see [`stage_advantage/README.md`](stage_advantage/README.md).

## Train-Deploy Alignment (Coming Soon)

65 changes: 51 additions & 14 deletions src/openpi/policies/agilex_policy.py
@@ -15,7 +15,9 @@ class AgilexInputs(transforms.DataTransformFn):
"""Inputs for the Agilex policy.

Expected inputs:
- images: dict[name, img] where img is [channel, height, width]. name must be in EXPECTED_CAMERAS.
- images: dict[name, img] where img is [channel, height, width]. For normal pi05
training, names must be exactly the keys of required_rename_map. For advantage
estimator, optional_rename_map keys may be included as well.
- state: [14]
- actions: [action_horizon, 14]
"""
@@ -28,13 +30,23 @@ class AgilexInputs(transforms.DataTransformFn):

# The expected cameras names. All input cameras must be in this set. Missing cameras will be
# replaced with black images and the corresponding `image_mask` will be set to False.
EXPECTED_CAMERAS: ClassVar[tuple[str, ...]] = ("top_head", "hand_left", "hand_right")

rename_map = {
required_rename_map = {
"top_head": "base_0_rgb",
"hand_left": "left_wrist_0_rgb",
"hand_right": "right_wrist_0_rgb"
}
# Optional cameras for advantage-estimator training (history frames).
optional_rename_map = {
"his_-100_top_head": "base_-100_rgb",
"his_-100_hand_left": "left_wrist_-100_rgb",
"his_-100_hand_right": "right_wrist_-100_rgb",
}

all_rename_map = {**required_rename_map, **optional_rename_map}

EXPECTED_CAMERAS: ClassVar[tuple[str, ...]] = tuple(required_rename_map.keys())
EXTRA_CAMERAS: ClassVar[tuple[str, ...]] = tuple(optional_rename_map.keys())

    # If set, replace the entire state with zeros
mask_state: bool = False
@@ -43,16 +55,22 @@ def __call__(self, data: dict) -> dict:
# We only mask padding for pi0 model, not pi0-FAST
mask_padding = self.model_type == _model.ModelType.PI0

in_images = data["images"]

if set(in_images) - set(self.EXPECTED_CAMERAS) - set(self.EXTRA_CAMERAS):
raise ValueError(f"Expected images to contain {self.EXPECTED_CAMERAS}, got {tuple(in_images)}")

# Pad the proprioceptive input to the action dimension of the model
state = transforms.pad_to_dim(data["state"], self.action_dim)
# Ensure state has correct shape [batch_size, state_dim]
state = state.squeeze()

# Parse images to uint8 (H,W,C) since LeRobot automatically stores as float32 (C,H,W)
images = {}
for camera in self.EXPECTED_CAMERAS:
if camera in data["images"]:
img = data["images"][camera]
image_masks = {}
for camera in self.EXPECTED_CAMERAS + self.EXTRA_CAMERAS:
if camera in in_images:
img = in_images[camera]
# Convert torch tensor to numpy array if needed
if isinstance(img, torch.Tensor):
img = img.cpu().numpy()
@@ -62,12 +80,14 @@
# Convert from [C,H,W] to [H,W,C] if needed
if img.shape[0] == 3:
img = np.transpose(img, (1, 2, 0))
images[self.rename_map[camera]] = img
images[self.all_rename_map[camera]] = img
image_masks[self.all_rename_map[camera]] = np.True_

elif camera not in in_images and camera in self.EXTRA_CAMERAS:
continue # optional camera can be skipped
else:
raise ValueError(f"Camera {camera} not found in data")

# Create image mask based on available cameras
image_mask = {self.rename_map[camera]: np.True_ for camera in self.EXPECTED_CAMERAS}

        # Filter abnormal state values (outside [-π, π]) by setting them to 0
state = np.where(state > np.pi, 0, state)
Expand All @@ -77,7 +97,7 @@ def __call__(self, data: dict) -> dict:
masked_state = np.zeros_like(state) if self.mask_state else state
inputs = {
"image": images,
"image_mask": image_mask,
"image_mask": image_masks,
"state": masked_state,
}

@@ -91,17 +111,34 @@ def __call__(self, data: dict) -> dict:
action_mask = np.ones_like(actions, dtype=bool)
action_mask[:, self.action_dim:] = False
inputs["action_mask"] = action_mask

if self.convert_to_eef_position:
actions[..., :14] = batch_qpos_to_eef_pos(actions[..., :14])

inputs["actions"] = actions.squeeze()

# Add prompt if present
if "prompt" in data:
inputs["prompt"] = data["prompt"]


# Advantage-estimator optional fields: passthrough or convert to tensor
for key in ("frame_index", "episode_length", "progress", "image_original", "episode_index"):
if key in data:
inputs[key] = data[key]

def _to_tensor(x, default=None):
if x is None and default is not None:
return default
if isinstance(x, np.ndarray):
return torch.from_numpy(x)
if isinstance(x, torch.Tensor):
return x.detach().clone()
raise NotImplementedError(f"Unsupported type: {type(x)}")

if "action_advantage" in data:
inputs["action_advantage"] = _to_tensor(data["action_advantage"], default=torch.tensor(1.0))
if "action_advantage_original" in data:
inputs["action_advantage_original"] = _to_tensor(data["action_advantage_original"])
return inputs


@dataclasses.dataclass(frozen=True)
class AgilexOutputs(transforms.DataTransformFn):
"""Outputs for the Agilex policy."""
69 changes: 53 additions & 16 deletions src/openpi/policies/arx_policy.py
@@ -9,12 +9,15 @@
import openpi.models.model as _model
import openpi.transforms as transforms


@dataclasses.dataclass(frozen=True)
class ARXInputs(transforms.DataTransformFn):
"""Inputs for the ARX policy.

Expected inputs:
- images: dict[name, img] where img is [channel, height, width]. name must be in EXPECTED_CAMERAS.
- images: dict[name, img] where img is [channel, height, width]. For normal pi05
training, names must be exactly the keys of required_rename_map. For advantage
estimator, optional_rename_map keys may be included as well.
- state: [14]
- actions: [action_horizon, 14]
"""
@@ -27,32 +30,47 @@ class ARXInputs(transforms.DataTransformFn):

# The expected cameras names. All input cameras must be in this set. Missing cameras will be
# replaced with black images and the corresponding `image_mask` will be set to False.
EXPECTED_CAMERAS: ClassVar[tuple[str, ...]] = ("top_head", "hand_left", "hand_right")

rename_map = {
required_rename_map = {
"top_head": "base_0_rgb",
"hand_left": "left_wrist_0_rgb",
"hand_right": "right_wrist_0_rgb"
}
# Optional cameras for advantage-estimator training (history frames).
optional_rename_map = {
"his_-100_top_head": "base_-100_rgb",
"his_-100_hand_left": "left_wrist_-100_rgb",
"his_-100_hand_right": "right_wrist_-100_rgb",
}

all_rename_map = {**required_rename_map, **optional_rename_map}

EXPECTED_CAMERAS: ClassVar[tuple[str, ...]] = tuple(required_rename_map.keys())
EXTRA_CAMERAS: ClassVar[tuple[str, ...]] = tuple(optional_rename_map.keys())

    # If set, replace the entire state with zeros
mask_state: bool = False


def __call__(self, data: dict) -> dict:
# We only mask padding for pi0 model, not pi0-FAST
mask_padding = self.model_type == _model.ModelType.PI0

in_images = data["images"]

if set(in_images) - set(self.EXPECTED_CAMERAS) - set(self.EXTRA_CAMERAS):
raise ValueError(f"Expected images to contain {self.EXPECTED_CAMERAS}, got {tuple(in_images)}")

# Pad the proprioceptive input to the action dimension of the model
state = transforms.pad_to_dim(data["state"], self.action_dim)
# Ensure state has correct shape [batch_size, state_dim]
state = state.squeeze()

# Parse images to uint8 (H,W,C) since LeRobot automatically stores as float32 (C,H,W)
images = {}
for camera in self.EXPECTED_CAMERAS:
if camera in data["images"]:
img = data["images"][camera]
image_masks = {}
for camera in self.EXPECTED_CAMERAS + self.EXTRA_CAMERAS:
if camera in in_images:
img = in_images[camera]
# Convert torch tensor to numpy array if needed
if isinstance(img, torch.Tensor):
img = img.cpu().numpy()
@@ -62,38 +80,57 @@ def __call__(self, data: dict) -> dict:
# Convert from [C,H,W] to [H,W,C] if needed
if img.shape[0] == 3:
img = np.transpose(img, (1, 2, 0))
images[self.rename_map[camera]] = img
images[self.all_rename_map[camera]] = img
image_masks[self.all_rename_map[camera]] = np.True_

elif camera not in in_images and camera in self.EXTRA_CAMERAS:
continue # optional camera can be skipped
else:
raise ValueError(f"Camera {camera} not found in data")

# Create image mask based on available cameras
image_mask = {self.rename_map[camera]: np.True_ for camera in self.EXPECTED_CAMERAS}

# Prepare inputs dictionary
masked_state = np.zeros_like(state) if self.mask_state else state
inputs = {
"image": images,
"image_mask": image_mask,
"image_mask": image_masks,
"state": masked_state,
}

# Add actions if present
if "actions" in data:
actions = transforms.pad_to_dim(data["actions"], self.action_dim)
# actions = np.where(actions > np.pi, 0, actions)
# actions = np.where(actions < -np.pi, 0, actions)
actions = np.where(actions > np.pi, 0, actions)
actions = np.where(actions < -np.pi, 0, actions)
if mask_padding:
# Create action mask for padding
action_mask = np.ones_like(actions, dtype=bool)
action_mask[:, self.action_dim:] = False
inputs["action_mask"] = action_mask

inputs["actions"] = actions.squeeze()

# Add prompt if present
if "prompt" in data:
inputs["prompt"] = data["prompt"]


# Advantage-estimator optional fields: passthrough or convert to tensor
for key in ("frame_index", "episode_length", "progress", "image_original", "episode_index"):
if key in data:
inputs[key] = data[key]

def _to_tensor(x, default=None):
if x is None and default is not None:
return default
if isinstance(x, np.ndarray):
return torch.from_numpy(x)
if isinstance(x, torch.Tensor):
return x.detach().clone()
raise NotImplementedError(f"Unsupported type: {type(x)}")

if "action_advantage" in data:
inputs["action_advantage"] = _to_tensor(data["action_advantage"], default=torch.tensor(1.0))
if "action_advantage_original" in data:
inputs["action_advantage_original"] = _to_tensor(data["action_advantage_original"])
return inputs

