Merged
76 changes: 60 additions & 16 deletions README.md
Original file line number Diff line number Diff line change
@@ -21,7 +21,7 @@

- **[Model Arithmetic](#model-arithmetic)**: A weight-space merging strategy that combines models trained on different data subsets, efficiently capturing diverse knowledge without architectural complexity. **[Released]**
- **[Stage Advantage](#stage-advantage)**: A stage-aware advantage estimator that provides stable, dense progress signals for policy training. **[Released]**
- **[Train-Deploy Alignment](#train-deploy-alignment-coming-soon)**: Bridges the distribution gap via spatio-temporal augmentation, heuristic DAgger corrections, and temporal chunk-wise smoothing. **[Coming Soon]**
- **[Train-Deploy Alignment](#train-deploy-alignment)**: Bridges the distribution gap via spatio-temporal augmentation, heuristic DAgger corrections, and temporal chunk-wise smoothing. **[Released]**
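
The merging idea behind Model Arithmetic can be sketched as a weighted sum in parameter space — each parameter of the merged model is a convex combination of the corresponding parameters of models trained on different data subsets. A toy illustration with flat dicts (the real implementation walks nested JAX/PyTorch checkpoint trees, and the 50/50 weights below are just an example):

```python
def merge_checkpoints(checkpoints, weights):
    """Merge a list of {param_name: value} dicts as a weighted sum.

    A minimal sketch of weight-space model arithmetic: each parameter of
    the merged model is the weighted combination of the corresponding
    parameters of the input models. Real checkpoints are nested trees of
    arrays; the same logic applies leaf-wise.
    """
    assert len(checkpoints) == len(weights)
    assert abs(sum(weights) - 1.0) < 1e-6, "weights should sum to 1 for interpolation"
    merged = {}
    for name in checkpoints[0]:
        merged[name] = sum(w * ckpt[name] for w, ckpt in zip(weights, checkpoints))
    return merged

# Two toy "models" trained on different data subsets.
m_a = {"layer.w": 1.0, "layer.b": 0.0}
m_b = {"layer.w": 3.0, "layer.b": 2.0}
print(merge_checkpoints([m_a, m_b], [0.5, 0.5]))  # → {'layer.w': 2.0, 'layer.b': 1.0}
```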

χ₀ enables two sets of dual-arm robots to collaboratively orchestrate long-horizon garment manipulation — flattening, folding, and hanging — surpassing the state-of-the-art $\pi_{0.5}$ baseline by approximately 250% in success rate, with **only 20 hours of data and 8 A100 GPUs**.

@@ -47,13 +47,15 @@ https://github.com/user-attachments/assets/3f5f0c48-ff3f-4b9b-985b-59ad0b2ea97c
- [Workflow](#workflow)
- [Quick Start](#quick-start)
- [Stage Advantage](#stage-advantage)
- [Train-Deploy Alignment (Coming Soon)](#train-deploy-alignment-coming-soon)
- [Train-Deploy Alignment](#train-deploy-alignment)
- [License and Citation](#license-and-citation)
- [Troubleshooting](#troubleshooting)
- [Links and Community](#links-and-community)

## Update

- [Feb 15 2026] Stage Advantage **advantage labels** (`Task_A/advantage/`) released on [Hugging Face](https://huggingface.co/datasets/OpenDriveLab-org/Kai0) and [ModelScope](https://www.modelscope.cn/datasets/OpenDriveLab/Kai0).
- [Feb 15 2026] Release of the **Train-Deploy Alignment** module: data augmentation (time scaling, space mirroring), DAgger data collection, inference with temporal smoothing/ensembling and RTC, and HDF5-to-LeRobot conversion.
- [Feb 14 2026] Release of the **Stage Advantage** module: advantage estimator training, evaluation, GT labeling, and AWBC training pipeline.
- [Feb 10 2026] Initial release of the **Model Arithmetic** module with support for both JAX and PyTorch checkpoints (not tested thoroughly).
- [Feb 10 2026] χ₀ paper released.
@@ -80,7 +82,7 @@ Non-edge components (e.g., Policy Training, Model Arithmetic) have been tested o

### Hardware

For real-robot deployment (dual-arm setup, cameras, and table layout), see **[Hardware Setup & 3D Print Files](setup/README.md)**. That document covers supported platforms (Agilex Piper for FlattenFold / TeeShirtSort, ARX X5 for HangCloth), Intel RealSense D435i camera placement, 3D-printed grippers and mounts with usage notes, and inference host GPU (RTX 4090 in Ubuntu 20.04).
For real-robot deployment (dual-arm setup, cameras, and table layout), see **[Hardware Setup & 3D Print Files](setup/README.md)**. That document covers supported platforms (Agilex Piper for Task_A / Task_B, ARX X5 for Task_C), Intel RealSense D435i camera placement, 3D-printed grippers and mounts with usage notes, and inference host GPU (RTX 4090 in Ubuntu 20.04).

## Installation

@@ -116,11 +118,11 @@ Download the Kai0 dataset so it is available under `./data` for training and eva
python scripts/download_dataset.py
```

This fetches the full dataset from [Hugging Face](https://huggingface.co/datasets/OpenDriveLab-org/Kai0) into `./data` (FlattenFold, HangCloth, TeeShirtSort). To download only specific tasks or use a custom path, see the [dataset docs](docs/dataset.md#step-1-download-the-dataset).
This fetches the full dataset from [Hugging Face](https://huggingface.co/datasets/OpenDriveLab-org/Kai0) into `./data` (Task_A, Task_B, Task_C). To download only specific tasks or use a custom path, see the [dataset docs](docs/dataset.md#step-1-download-the-dataset).

### 2. Download checkpoints (optional, for testing)

We provide **one best model per task** (FlattenFold, HangCloth, TeeShirtSort) in the [Kai0 repo on Hugging Face](https://huggingface.co/OpenDriveLab-org/Kai0/tree/main).
We provide **one best model per task** (Task_A, Task_B, Task_C) in the [Kai0 repo on Hugging Face](https://huggingface.co/OpenDriveLab-org/Kai0/tree/main).

From the repository root, you can download all best-model checkpoints to `./checkpoints` with:

@@ -131,7 +133,7 @@
To download only specific tasks or use a custom path, run:

```bash
python scripts/download_checkpoints.py --tasks FlattenFold HangCloth --local-dir ./my_checkpoints
python scripts/download_checkpoints.py --tasks Task_A Task_C --local-dir ./my_checkpoints
```

After download, set `weight_loader` in the training config to the path of the corresponding checkpoint directory (see step 3 below). You can also use openpi’s pretrained π₀.₅ checkpoint instead.
@@ -144,7 +146,7 @@ After the dataset is in `./data`, you can run **normal π₀.₅ full fine-tunin

Edit [`src/openpi/training/config.py`](src/openpi/training/config.py) (around lines 1173–1226) for the task(s) you need:

- **`repo_id`**: set to the **absolute path** to the dataset subset, e.g. `<path_to_repo_root>/data/FlattenFold/base`, `<path_to_repo_root>/data/TeeShirtSort/base`, or `<path_to_repo_root>/data/HangCloth/base`.
- **`repo_id`**: set to the **absolute path** to the dataset subset, e.g. `<path_to_repo_root>/data/Task_A/base`, `<path_to_repo_root>/data/Task_B/base`, or `<path_to_repo_root>/data/Task_C/base`.
- **`weight_loader`**: set to the path of your **π₀.₅ base checkpoint** — either the best model you downloaded in step 2 above, or openpi’s pretrained π₀.₅ checkpoint.

Use the config name that matches your task, e.g. `pi05_flatten_fold_normal`.
@@ -210,8 +212,8 @@ Checkpoints are written to the config’s checkpoint directory. You can then use
- [x] kai0 oracle: training and inference code using the non-advantage data for all three tasks
- [x] Model Arithmetic: code of different baselines for weight-space interpolation
- [x] Stage Advantage: code, data (advantage labels), and checkpoints
- [ ] HuggingFace & ModelScope: upload Stage Advantage data and checkpoints — **Feb 14**
- [ ] Train-Deploy Alignment — **Feb 14**
- [x] Train-Deploy Alignment: data augmentation, DAgger, inference (temporal smoothing, ensembling, RTC)
- [x] HuggingFace & ModelScope: Stage Advantage data (`Task_A/advantage/`) and checkpoints uploaded

## Model Arithmetic

@@ -300,7 +302,7 @@ For a ready-to-use script with environment setup (conda/venv activation, DDP con
**Stage 2 — Advantage Estimation on New Data**: Use the trained estimator to label datasets with predicted advantage values.

```bash
uv run python stage_advantage/annotation/eval.py Flatten-Fold KAI0 /path/to/dataset
uv run python stage_advantage/annotation/eval.py Task-A KAI0 /path/to/dataset
```

For a ready-to-use script with environment setup and status logging, see `stage_advantage/annotation/eval.sh`.
@@ -315,14 +317,56 @@ For a ready-to-use script with environment setup and automatic log management, s

For the full pipeline details, configuration instructions, and all parameters, see [`stage_advantage/README.md`](stage_advantage/README.md).
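
The AWBC stage of this pipeline reweights the imitation loss by estimated advantage. A generic sketch of advantage-weighted behavior cloning (the exponential weighting, `beta`, and clipping below are illustrative assumptions, not necessarily the repo's exact scheme):

```python
import numpy as np

def awbc_weights(advantages, beta=1.0, clip=10.0):
    """Exponential advantage weighting for behavior cloning.

    Higher-advantage (faster-progressing) samples get larger weights;
    clipping keeps outliers from dominating the batch.
    """
    return np.minimum(np.exp(np.asarray(advantages) / beta), clip)

def awbc_loss(pred_actions, target_actions, advantages, beta=1.0):
    """Per-sample MSE weighted by exp(advantage / beta), averaged over the batch."""
    w = awbc_weights(advantages, beta)
    per_sample = np.mean((pred_actions - target_actions) ** 2, axis=-1)
    return float(np.mean(w * per_sample))
```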

## Train-Deploy Alignment (Coming Soon)
## Train-Deploy Alignment

Train-Deploy Alignment bridges the distribution gap between training and real-world deployment through:
- **Spatio-temporal augmentation**: Data augmentation including space mirroring and time scaling for dual-arm setups.
- **Heuristic DAgger corrections**: Interactive on-robot data collection for iterative policy improvement.
- **Temporal chunk-wise smoothing**: Smoothed action execution to reduce jitter during deployment.
Train-Deploy Alignment bridges the distribution gap between training and real-world deployment through three sub-modules:

**This module is currently under refinement and will be released soon.**
- **Data Augmentation** (`train_deploy_alignment/data_augment/`): Time scaling (frame extraction at configurable rates), space mirroring (left/right arm swap + video flip), dataset merging, and HDF5-to-LeRobot format conversion.
- **DAgger** (`train_deploy_alignment/dagger/`): Policy-in-the-loop data collection for both Agilex Piper and ARX X5 platforms. Operators run inference, switch to DAgger mode for human corrections, and save episodes (HDF5 + optional videos + intervention labels).
- **Inference** (`train_deploy_alignment/inference/`): Deployment code for Agilex and ARX robots with multiple execution modes — synchronous, temporal smoothing, temporal ensembling, and **RTC (real-time chunking)**. Uses a two-machine setup (GPU policy server + robot IPC client).
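
The execution modes above differ in how overlapping action chunks are combined at deployment time. Temporal ensembling, for instance, can be sketched as an exponentially weighted average over every chunk that predicted an action for the current timestep — a generic illustration in the style of ACT-like controllers (function name and decay constant `m` are assumptions, not the repo's exact implementation):

```python
import numpy as np

def temporal_ensemble(chunks, t, m=0.1):
    """Combine overlapping action-chunk predictions for timestep t.

    chunks: list of (start_step, actions) pairs, where `actions` has shape
    (horizon, action_dim). Every chunk that covers t contributes its
    prediction, weighted by exp(-m * age) so older predictions count less.
    """
    preds, weights = [], []
    for start, actions in chunks:
        offset = t - start  # how far into this chunk timestep t falls
        if 0 <= offset < len(actions):
            preds.append(actions[offset])
            weights.append(np.exp(-m * offset))
    w = np.array(weights) / np.sum(weights)
    return np.sum(np.stack(preds) * w[:, None], axis=0)
```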

### Quick Start

**Data Augmentation — Time scaling:**

```bash
python train_deploy_alignment/data_augment/time_scaling.py \
--src_path /path/to/source --tgt_path /path/to/extracted --repo_id extracted_dataset \
--extraction_factor 2
```
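
Conceptually, time scaling subsamples each episode at the extraction factor — with `--extraction_factor 2`, every other frame is kept, yielding a faster-looking execution of the same trajectory. A simplified view (the actual script also rewrites timestamps and per-episode metadata):

```python
def time_scale(frames, extraction_factor=2):
    """Keep every `extraction_factor`-th frame of an episode.

    For extraction_factor=2 this halves episode length, augmenting the
    data with sped-up versions of the same demonstrations.
    """
    return frames[::extraction_factor]

print(time_scale(list(range(10)), 2))  # → [0, 2, 4, 6, 8]
```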

**Data Augmentation — Space mirroring (mirror + merge):**

```bash
python train_deploy_alignment/data_augment/space_mirroring.py full \
--src-path /path/to/original --mirror-path /path/to/mirrored --merge-path /path/to/merged \
--repo-id my_dataset
```
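
Conceptually, the mirroring step swaps the left- and right-arm segments of each state/action vector and horizontally flips the camera frames. A minimal sketch — the `[left_arm(7), right_arm(7)]` layout is a hypothetical example, and per-axis sign corrections are omitted (the script derives the real layout from dataset metadata):

```python
import numpy as np

def mirror_sample(state, image, arm_dim=7):
    """Swap left/right arm segments and flip the image horizontally.

    Assumes state is laid out as [left_arm(arm_dim), right_arm(arm_dim)];
    image is (H, W, C), so reversing axis 1 mirrors it left-to-right.
    """
    mirrored_state = np.concatenate([state[arm_dim:], state[:arm_dim]])
    mirrored_image = image[:, ::-1]
    return mirrored_state, mirrored_image
```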

**DAgger — Agilex:** Start the policy server on the GPU host, then on the IPC:

```bash
conda activate kai0_inference
python train_deploy_alignment/dagger/agilex/agilex_openpi_dagger_collect.py \
--host <gpu_host_ip> --port 8000 --ctrl_type joint --use_temporal_smoothing --chunk_size 50 \
--dataset_name <your_dataset_name>
```

**Inference — Agilex (temporal smoothing):** Start the policy server on the GPU host, then on the IPC:

```bash
conda activate kai0_inference
python train_deploy_alignment/inference/agilex_inference_openpi_temporal_smoothing.py \
--host <gpu_host_ip> --port 8000 --ctrl_type joint --use_temporal_smoothing --chunk_size 50
```

**Inference — ARX (RTC mode):** Start the policy server with an RTC config, then on the IPC:

```bash
python train_deploy_alignment/inference/arx_openpi_inference_rtc.py --host <gpu_host_ip> --port 8000 --rtc_mode --chunk_size 50
```
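
The RTC handoff can be pictured as constraining the start of each new chunk: while inference runs for `inference_delay` steps, the robot keeps executing the previous chunk, so the first steps of the new chunk are pinned to those already-committed actions. A simplified, post-hoc view of the masking done in `pi0_rtc.py` (the real model applies this inside the denoising loop, and assumes `prev_chunk` is time-aligned with `new_chunk`):

```python
import numpy as np

def pin_rtc_prefix(new_chunk, prev_chunk, inference_delay):
    """Pin the first `inference_delay` actions of a new chunk to the actions
    already committed from the previous chunk, so execution stays continuous
    across the chunk boundary. Both chunks are (horizon, action_dim) and are
    assumed to start at the same timestep."""
    out = np.array(new_chunk, copy=True)
    out[:inference_delay] = prev_chunk[:inference_delay]
    return out
```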

For full setup instructions (IPC environment, CAN, ROS/ROS2, platform-specific details), see [`train_deploy_alignment/README.md`](train_deploy_alignment/README.md).

## License and Citation

20 changes: 10 additions & 10 deletions setup/README.md
@@ -6,15 +6,15 @@ Quick reference for deploying and debugging hardware for the supported task plat

## Table of Contents

- [1. FlattenFold / TeeShirtSort (Agilex Piper)](#1-flattenfold--teeshirtsort-agilex-piper)
- [2. HangCloth (ARX X5)](#2-hangcloth-arx-x5)
- [1. Task_A / Task_B (Agilex Piper)](#1-task_a--task_b-agilex-piper)
- [2. Task_C (ARX X5)](#2-task_c-arx-x5)
- [3. Inference Host](#3-inference-host)

---

## 1. FlattenFold / TeeShirtSort (Agilex Piper)
## 1. Task_A / Task_B (Agilex Piper)

**Directories:** `FlattenFold/`, `TeeShirtSort/`
**Directories:** `Task_A/`, `Task_B/`

### 1.1 Components

@@ -24,7 +24,7 @@ Quick reference for deploying and debugging hardware for the supported task plat
| Cameras | Intel RealSense D435 (triple-camera setup) |
| Printed parts | Left/right wrist camera mounts, center camera mount, center camera base |

### 1.2 FlattenFold Layout
### 1.2 Task_A Layout

| Parameter | Value |
|-----------|-------|
@@ -37,7 +37,7 @@ Quick reference for deploying and debugging hardware for the supported task plat
| Right primary arm → table front edge | 12 cm |
| Left–right primary arm center distance | 39 cm |

### 1.3 TeeShirtSort Layout (demoA-style)
### 1.3 Task_B Layout (demoA-style)

| Parameter | Value |
|-----------|-------|
@@ -50,7 +50,7 @@ Quick reference for deploying and debugging hardware for the supported task plat
| Right primary arm → table front edge | 11 cm |
| Left–right primary arm center distance | 40 cm |

### 1.4 3D Models — Usage (FlattenFold / TeeShirtSort)
### 1.4 3D Models — Usage (Task_A / Task_B)

#### Gripper (end-effector)

@@ -77,9 +77,9 @@ Quick reference for deploying and debugging hardware for the supported task plat

---

## 2. HangCloth (ARX X5)
## 2. Task_C (ARX X5)

**Directory:** `HangCloth/`
**Directory:** `Task_C/`

### 2.1 Components

@@ -102,7 +102,7 @@ Quick reference for deploying and debugging hardware for the supported task plat
| Right primary arm → table front edge | 11 cm |
| Left–right primary arm center distance | 53 cm |

### 2.3 3D Models — Usage (HangCloth)
### 2.3 3D Models — Usage (Task_C)

#### Grippers (secondary arms)

2 changes: 2 additions & 0 deletions src/openpi/models/model.py
@@ -33,6 +33,8 @@ class ModelType(enum.Enum):
PI0 = "pi0"
PI0_FAST = "pi0_fast"
PI05 = "pi05"
PI0_RTC = "pi0_rtc"
PI05_RTC = "pi05_rtc"


# The model always expects these images
28 changes: 28 additions & 0 deletions src/openpi/models/pi0_config.py
@@ -13,6 +13,7 @@

if TYPE_CHECKING:
from openpi.models.pi0 import Pi0
from openpi.models.pi0_rtc import Pi0RTC


@dataclasses.dataclass(frozen=True)
@@ -107,6 +108,33 @@ def get_freeze_filter(self) -> nnx.filterlib.Filter:
return nnx.Nothing
return nnx.All(*filters)


@dataclasses.dataclass(frozen=True)
class Pi0RTCConfig(Pi0Config):
"""Config for Pi0RTC (real-time control) model. Uses same architecture as Pi0/Pi05 but sample_actions supports
prev_action_chunk, inference_delay, execute_horizon for RTC guidance. Use this config when serving
for RTC inference (e.g. agilex_inference_openpi_rtc.py). Set pi05=True for Pi05-based RTC (model_type PI05_RTC)."""

@property
@override
def model_type(self) -> _model.ModelType:
return _model.ModelType.PI05_RTC if self.pi05 else _model.ModelType.PI0_RTC

@override
def create(self, rng: at.KeyArrayLike) -> "Pi0RTC":
from openpi.models.pi0_rtc import Pi0RTC

return Pi0RTC(self, rngs=nnx.Rngs(rng))

@override
def load_pytorch(self, train_config, weight_path: str):
"""RTC model is JAX-only; use a JAX checkpoint with serve_policy and Pi0RTCConfig."""
raise NotImplementedError(
"Pi0RTC is only supported with JAX checkpoints. Use a checkpoint saved from OpenPi JAX training "
"(params directory, not model.safetensors) and serve with --policy.config=pi05_rtc_flatten_fold_inference (or your RTC config name)."
)


@dataclasses.dataclass(frozen=True)
class AdvantageEstimatorConfig(Pi0Config):
# * Custom
2 changes: 1 addition & 1 deletion src/openpi/models/pi0_rtc.py
@@ -322,7 +322,7 @@ def rtc_step(carry):
x_t_for_denoise = x_t
if mask_prefix_delay and provided_dim > 0:
mask_time = (jnp.arange(self.action_horizon) < d).astype(bool)[None, :, None]
# 仅覆盖提供的维度,其余保持 x_t 原值
# Overwrite only the provided dims in the delay prefix; leave the rest as x_t.
overwrite = jnp.where(mask_time, prev_chunk[..., :provided_dim], x_t_for_denoise[..., :provided_dim])
x_t_for_denoise = x_t_for_denoise.at[..., :provided_dim].set(overwrite)

4 changes: 2 additions & 2 deletions src/openpi/policies/agilex_policy.py
@@ -52,8 +52,8 @@ class AgilexInputs(transforms.DataTransformFn):
mask_state: bool = False

def __call__(self, data: dict) -> dict:
# We only mask padding for pi0 model, not pi0-FAST
mask_padding = self.model_type == _model.ModelType.PI0
# We only mask padding for pi0/pi0_rtc model, not pi05/pi05_rtc or pi0-FAST
mask_padding = self.model_type in (_model.ModelType.PI0, _model.ModelType.PI0_RTC)

in_images = data["images"]

4 changes: 2 additions & 2 deletions src/openpi/policies/arx_policy.py
@@ -52,8 +52,8 @@ class ARXInputs(transforms.DataTransformFn):
mask_state: bool = False

def __call__(self, data: dict) -> dict:
# We only mask padding for pi0 model, not pi0-FAST
mask_padding = self.model_type == _model.ModelType.PI0
# We only mask padding for pi0/pi0_rtc model, not pi05/pi05_rtc or pi0-FAST
mask_padding = self.model_type in (_model.ModelType.PI0, _model.ModelType.PI0_RTC)

in_images = data["images"]

2 changes: 1 addition & 1 deletion src/openpi/policies/droid_policy.py
@@ -45,7 +45,7 @@ def __call__(self, data: dict) -> dict:
wrist_image = _parse_image(data["observation/wrist_image_left"])

match self.model_type:
case _model.ModelType.PI0 | _model.ModelType.PI05:
case _model.ModelType.PI0 | _model.ModelType.PI05 | _model.ModelType.PI0_RTC | _model.ModelType.PI05_RTC:
names = ("base_0_rgb", "left_wrist_0_rgb", "right_wrist_0_rgb")
images = (base_image, wrist_image, np.zeros_like(base_image))
image_masks = (np.True_, np.True_, np.False_)
23 changes: 20 additions & 3 deletions src/openpi/training/config.py
@@ -115,7 +115,7 @@ class ModelTransformFactory(GroupFactory):

def __call__(self, model_config: _model.BaseModelConfig) -> _transforms.Group:
match model_config.model_type:
case _model.ModelType.PI0:
case _model.ModelType.PI0 | _model.ModelType.PI0_RTC:
return _transforms.Group(
inputs=[
_transforms.InjectDefaultPrompt(self.default_prompt),
Expand All @@ -126,7 +126,7 @@ def __call__(self, model_config: _model.BaseModelConfig) -> _transforms.Group:
_transforms.PadStatesAndActions(model_config.action_dim),
],
)
case _model.ModelType.PI05:
case _model.ModelType.PI05 | _model.ModelType.PI05_RTC:
assert isinstance(model_config, pi0_config.Pi0Config)
return _transforms.Group(
inputs=[
@@ -187,7 +187,7 @@ def create_base_config(self, assets_dirs: pathlib.Path, model_config: _model.Bas
repo_id=repo_id,
asset_id=asset_id,
norm_stats=self._load_norm_stats(epath.Path(self.assets.assets_dir or assets_dirs), asset_id),
use_quantile_norm=model_config.model_type != ModelType.PI0,
use_quantile_norm=model_config.model_type not in (ModelType.PI0, ModelType.PI0_RTC),
)

def _load_norm_stats(self, assets_dir: epath.Path, asset_id: str | None) -> dict[str, _transforms.NormStats] | None:
@@ -1371,6 +1371,23 @@ def __post_init__(self) -> None:
num_workers=8,
batch_size=256,
),

#**************************FlattenFold RTC Inference*******************************
# Use this config when serving the policy for agilex_inference_openpi_rtc.py (JAX checkpoints only).
TrainConfig(
name="pi05_rtc_flatten_fold_inference",
model=pi0_config.Pi0RTCConfig(pi05=True),
data=LerobotAgilexDataConfig(
repo_id="<path_to_repo_root>/data/FlattenFold/base",
default_prompt="Flatten and fold the cloth.",
use_delta_joint_actions=False,
),
weight_loader=weight_loaders.CheckpointWeightLoader("<path_to/pi05_base/checkpoint>"),
num_train_steps=100_000,
keep_period=5000,
num_workers=8,
batch_size=256,
),
# RoboArena & PolaRiS configs.
*roboarena_config.get_roboarena_configs(),
*polaris_config.get_polaris_configs(),