Official PyTorch implementation of the paper "Latent Particle World Models: Self-supervised Object-centric Stochastic Dynamics Modeling".
Tal Daniel • Carl Qi • Dan Haramati • Amir Zadeh • Chuan Li • Aviv Tamar • Deepak Pathak • David Held
Arxiv • Project Website • Video • OpenReview
Latent Particle World Models: Self-supervised Object-centric Stochastic Dynamics Modeling
Abstract: We introduce the Latent Particle World Model (LPWM), a self-supervised object-centric world model scaled to real-world multi-object datasets and applicable to decision-making. LPWM autonomously discovers keypoints, bounding boxes, and object masks directly from video data, enabling it to learn rich scene decompositions without supervision. Our architecture is trained end-to-end purely from videos and supports flexible conditioning on actions, language, and image goals. LPWM models stochastic particle dynamics via a novel latent action module and achieves state-of-the-art results on diverse real-world and synthetic datasets. Beyond stochastic video modeling, LPWM is readily applicable to decision-making, including goal-conditioned imitation learning, as we demonstrate in the paper.
Daniel, T., Qi, C., Haramati, D., Zadeh, A., Li, C., Tamar, A., Pathak, D., & Held, D. (2026). Latent particle world models: Self-supervised object-centric stochastic dynamics modeling. In The Fourteenth International Conference on Learning Representations (ICLR 2026). https://openreview.net/forum?id=lTaPtGiUUc
```bibtex
@inproceedings{daniel2026latent,
  title={Latent Particle World Models: Self-supervised Object-centric Stochastic Dynamics Modeling},
  author={Tal Daniel and Carl Qi and Dan Haramati and Amir Zadeh and Chuan Li and Aviv Tamar and Deepak Pathak and David Held},
  booktitle={The Fourteenth International Conference on Learning Representations},
  year={2026},
  url={https://openreview.net/forum?id=lTaPtGiUUc}
}
```
```shell
# Install environment
conda env create -f environment.yml
conda activate dlp
# Train LPWM on Sketchy
python train_lpwm.py --dataset sketchy
# Generate videos with a pretrained model
python generate_lpwm_video_prediction.py --help
```

- lpwm
- Latent Particle World Models: Self-supervised Object-centric Stochastic Dynamics Modeling
- Citation
- Prerequisites
- Model Zoo - Pretrained Models
- Datasets
- Conditioning Types, Multi-View and Examples
- Attribute Variable Names
- DLPv3 and LPWM - Training
- Training Logs, Progress Bar and Saved Images/Videos
- LPWM - Evaluation
- LPWM - Video Prediction and Generation with Pre-trained Models
- Example Usage, Documentation and Notebooks
- Repository Organization
We provide an environment.yml file which installs the required packages in a conda environment named torch.
Alternatively, you can use pip to install requirements.txt.
- Create the environment with: conda env create -f environment.yml
| Library | Version | Notes |
|---|---|---|
| Python | >= 3.9 | - |
| torch | >= 2.6.0 | - |
| torchvision | >= 0.21 | - |
| matplotlib | >= 3.10.0 | - |
| numpy | >= 1.24.3 | - |
| h5py | >= 3.13.0 | Some datasets (e.g., Balls) use H5/HDF5 |
| py-opencv | >= 4.11 | For plotting |
| tqdm | >= 4.67.0 | - |
| scipy | >= 1.15 | - |
| scikit-image | >= 0.25.2 | Required to generate the "Shapes" dataset |
| ffmpeg | = 4.2.2 | Required to generate video files |
| accelerate | >= 1.5.0 | For multi-GPU training |
| imageio | >= 2.6.1 | For creating video GIFs |
| piqa | >= 1.3.1 | For image evaluation metrics: LPIPS, SSIM, PSNR |
| einops | >= 0.81 | - |
| huggingface-hub | >= 0.29 | For downloading checkpoints and datasets |
| notebook | >= 6.5.4 | To run Jupyter Notebooks |
For a manual installation guide, see docs/installation.md.
- We provide pre-trained checkpoints for datasets used in the paper.
- All model checkpoints should be placed inside the /checkpoints directory.
The following table lists the available pre-trained checkpoints and where to download them.
| Model Type | Dataset | Link |
|---|---|---|
| LPWM | Sketchy (128x128) | MEGA.nz |
| LPWM-Action | Sketchy (128x128) | MEGA.nz |
| LPWM | BAIR (128x128) | MEGA.nz |
| LPWM-Language | LanguageTable (128x128) | MEGA.nz |
| LPWM-Language | Bridge (128x128) | MEGA.nz |
| Dataset | Notes | Link |
|---|---|---|
| Sketchy | Original dataset from DeepMind. We use a subset and provide the pre-processed data on HF. | HF |
| Bridge | Pre-processed videos from the BRIDGE dataset with T5-large embeddings for the text instructions. If you want to prepare the data yourself, download from Open-X and follow datasets/bridge_preparation.py. | HF |
| BAIR | We use a high-resolution version of the BAIR dataset, courtesy of PVG. | HF |
| LanguageTable | Follow the download and pre-processing instructions in datasets/langtable_preparation.py. It requires the tensorflow and transformers libraries. | datasets/langtable_preparation.py, Open-X |
| Panda | Expert imitation trajectories for several manipulation tasks from the IsaacGym simulator. Used in EC-Diffuser. | HF |
| OGBench | Offline RL trajectories from OGBench scene and cube tasks. | HF |
| Mario | A small dataset of Mario gameplay videos divided into 100-frame episodes. | HF |
| OBJ3D | Courtesy of G-SWM. | Google Drive, HF |
| PHYRE | See datasets/phyre_preparation.py to generate new data, or download the data we generated. | MEGA.nz |
| Balls | Synthetic, courtesy of G-SWM; see this link to generate new data, or download the data we generated. | MEGA.nz |
| Shapes | Synthetic, generated on-the-fly. | See generate_shape_dataset_torch() in datasets/shapes_ds.py |
| Custom Dataset | 1. Implement a Dataset class (see examples in /datasets). 2. Add it to get_image_dataset() and get_video_dataset() in /datasets/get_dataset.py. 3. Prepare a JSON config file with the hyperparameters and place it in /configs. | - |
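The three custom-dataset steps can be sketched as follows. This is a minimal, self-contained illustration: the class name MyVideoDataset, the episode layout, and the returned dictionary keys are assumptions for the sketch, not code from this repo (in practice the class would subclass torch.utils.data.Dataset and return image tensors):

```python
# Step 1 (sketch): a minimal video dataset. In the repo this would subclass
# torch.utils.data.Dataset and return (T, C, H, W) image tensors in [0, 1];
# plain-Python stand-ins are used here to keep the sketch self-contained.
class MyVideoDataset:
    def __init__(self, root, mode="train", sample_length=10):
        self.root = root                    # path to the raw data
        self.mode = mode                    # "train" / "valid" / "test"
        self.sample_length = sample_length  # frames per sampled clip
        # In practice, index the episode files found under `root` here.
        self.episodes = [f"episode_{i:04d}" for i in range(100)]  # placeholder

    def __len__(self):
        return len(self.episodes)

    def __getitem__(self, idx):
        episode = self.episodes[idx]
        # In practice, load `sample_length` consecutive frames from `episode`.
        frames = [f"{episode}/frame_{t}" for t in range(self.sample_length)]
        return {"images": frames}

# Step 2 (sketch): register the dataset in /datasets/get_dataset.py, e.g.
# inside get_video_dataset():
#     if ds == "my_dataset":
#         return MyVideoDataset(root=data_root, mode=mode, sample_length=seq_len)
# Step 3: add a matching /configs/my_dataset.json with the hyperparameters.
```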
| Condition Type | Required Flags | Dataset Example | config.json Example |
|---|---|---|---|
| None (only latent actions) | - | /datasets/bair_ds.py, /datasets/mario_ds.py | /configs/bair.json, /configs/mario.json |
| Action | "action_condition": true, "action_dim": 7 | /datasets/sketchy_ds.py, /datasets/langtable_ds.py | /configs/sketchy_action.json, /configs/langtable_action.json |
| Language | "language_condition": true, "language_embed_dim": 1024 (T5), "language_max_len": 32 (max language tokens) | /datasets/langtable_ds.py, /datasets/bridge_ds.py | /configs/bridge.json, /configs/langtable.json |
| Image | "image_goal_condition": true | /datasets/ogbench_ds.py, /datasets/panda_ds.py | /configs/ogbench.json, /configs/panda.json |
| Multi-view | "n_views": 2 | /datasets/panda_ds.py | /configs/panda.json |
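For instance, a language-conditioned config would include the flags above alongside the dataset's other hyperparameters. The snippet below is an illustrative fragment only; the "ds" field name and the surrounding structure are assumptions, not copied from /configs/bridge.json:

```json
{
  "ds": "bridge",
  "language_condition": true,
  "language_embed_dim": 1024,
  "language_max_len": 32
}
```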
Most latent attribute variable names in the code follow the paper's convention (e.g., z_scale, z_depth, ...). Exceptions:
- Position (KP): z
- Transparency: z_obj_on / obj_on
- Visual features: z_features
You can train the models on single-GPU and multi-GPU machines. For multi-GPU training we use HuggingFace Accelerate: pip install accelerate.

- Set visible GPUs with os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1, 2, 3" (NUM_GPUS=4).
- Set "num_processes": NUM_GPUS in accel_conf.yml (e.g., "num_processes": 4 if os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1, 2, 3").
- Single-GPU machines: python train_dlp.py -d {dataset} / python train_lpwm.py -d {dataset}
- Multi-GPU machines: accelerate launch --config_file ./accel_conf.yml train_dlp_accelerate.py -d {dataset} / accelerate launch --config_file ./accel_conf.yml train_lpwm_accelerate.py -d {dataset}
Config files for the datasets are located in /configs/{ds_name}.json. You can modify hyperparameters in these files.
To generate a config file for a new dataset, you can copy one of the files in /configs or use the /configs/generate_config_file.py script.
Hyperparameters: See /docs/hyperparameters.md for extended details and recommended values.
The scripts train_dlp.py/train_lpwm.py and train_dlp_accelerate.py/train_lpwm_accelerate.py use the config file /configs/{ds_name}.json.
Examples:

```shell
python train_dlp.py --dataset shapes
python train_dlp.py --dataset obj3d_img
python train_lpwm.py --dataset mario
python train_lpwm.py --dataset bridge
accelerate launch --config_file ./accel_conf.yml train_dlp_accelerate.py --dataset obj3d_img
accelerate launch --config_file ./accel_conf.yml train_lpwm_accelerate.py --dataset obj3d128
```
- Note: if you want multiple multi-GPU runs, each run should have a different accelerate config file (e.g., accel_conf.yml, accel_conf_2.yml, etc.). The only difference between the files should be the main_process_port field (e.g., for the second config file, set main_process_port: 81231).
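For reference, the relevant fields of such a second config might look as follows. This is an illustrative sketch of a HuggingFace Accelerate config, not the repo's accel_conf.yml; only main_process_port needs to differ between concurrent runs:

```yaml
# accel_conf_2.yml (illustrative): identical to accel_conf.yml except the port
compute_environment: LOCAL_MACHINE
distributed_type: MULTI_GPU
num_machines: 1
num_processes: 4          # = NUM_GPUS
main_process_port: 81231  # the only field that should differ per run
mixed_precision: "no"
```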
During training, the script saves (locally) a log file with the metrics output and images in the following structure:
where the columns are different images in the batch (DLPv3) or an image sequence (LPWM) and the rows correspond to:
| Row | Image Meaning |
|---|---|
| 1 | Ground-truth (GT) original input image |
| 2 | GT image + all L (DLP)/M (LPWM) posterior particles |
| 3 | Reconstruction of the entire scene |
| 4 | GT image + all M prior keypoints (proposals) |
| 5 | GT image + top-K posterior particle filtered by their uncertainty |
| 6 | Foreground reconstruction (decoded glimpses and masks of the particles) |
| 7 | GT image + bounding boxes based on the scale attribute z_s (+ non-maximal suppression) |
| 8 | GT image + bounding boxes based on the decoded particles masks (+ non-maximal suppression) |
| 9 | Colored segmentation masks (each particle's alpha map is colored differently) |
| 10 | Background reconstruction |
During training, the script saves animation videos of rollouts with 3 panels:
| Ground Truth | Latent-action Conditioned | Sampling |
|---|---|---|
| The real video from the data | Latent actions are extracted from the GT video and used to condition the dynamics module, together with the first frame of the GT video | Latent actions are sampled from the first frame |
- For image conditioning: the image will be plotted in the animation as "Goal".
- For language conditioning: the text will be added as title.
More examples are here.
Progress Bar: In addition to the losses (reconstruction and KLs), the progress bar displays:
- on_l1 - the average number of visible particles (the average z_t). This is useful to diagnose whether objects are not detected (=0) or whether the model is saturated (=MAX_PARTICLES) and all the information is in the FG (e.g., the BG is not learned).
- a, b - the average values of the concentration parameters of the Beta distribution (Beta(a, b)) for the transparency z_t (z_obj_on). They need to be balanced and not too large/small.
- smu - the average mu of the scale attribute z_scale (also the bounding boxes). Make sure it is not too large.
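To make the on_l1 diagnostic concrete, here is a sketch of how an "average number of visible particles" can be computed from per-particle transparencies; the array values and the exact reduction are illustrative assumptions, not the repo's logging code:

```python
import numpy as np

# Hypothetical batch of per-particle transparencies z_obj_on in [0, 1],
# shape (batch_size, n_particles).
z_obj_on = np.array([[1.0, 0.9, 0.0, 0.1],
                     [1.0, 1.0, 0.0, 0.0]])

# Summing transparencies over particles gives the expected number of visible
# particles per image; averaging over the batch yields the scalar diagnostic.
on_l1 = z_obj_on.sum(axis=-1).mean()
print(on_l1)  # 2.0 -> roughly 2 of the 4 particles are "on"
```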
The evaluation protocol measures reconstruction quality via three metrics (LPIPS, SSIM, and PSNR) and video generation quality via FVD.
We use the open-source piqa library (pip install piqa) to compute the image metrics. If eval_im_metrics=True in the config file, the metrics are computed every evaluation epoch on the validation set. The code saves the best model based on the LPIPS metric.
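As a sanity check on metric conventions, PSNR (the simplest of the three image metrics) can be computed by hand. This standalone numpy sketch assumes pixel values in [0, 1] and is not the repo's piqa-based implementation:

```python
import numpy as np

def psnr(x, y, max_val=1.0):
    """PSNR in dB between two images with pixel values in [0, max_val]."""
    x = np.asarray(x, dtype=np.float64)
    y = np.asarray(y, dtype=np.float64)
    mse = np.mean((x - y) ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return 10.0 * np.log10(max_val ** 2 / mse)

# A uniform per-pixel error of 0.1 gives MSE = 0.01, i.e., PSNR = 20 dB.
a = np.zeros((3, 8, 8))
b = np.full((3, 8, 8), 0.1)
print(round(psnr(a, b), 4))  # 20.0
```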
To evaluate a pre-trained LPWM model (on the test set) on video prediction, we provide a script:
python eval/eval_gen_metrics.py --help
For example, to evaluate a model saved in /checkpoints/sketchy/sketchy.pth on the test set, with 6 conditional frames and a generation horizon of 50 frames:
```shell
python eval/eval_gen_metrics.py -d sketchy -p ./checkpoints/sketchy --checkpoint ./checkpoints/sketchy/sketchy.pth --sample -b 10 -c 6 --horizon 50 --prefix "" --ctx
```
where b is the batch size and ctx specifies that latent actions should be used to generate the video (meaning the script first extracts the latent actions from the GT video and uses them to condition the dynamics module).
The script will load the config file hparams.json from ./checkpoints/sketchy/hparams.json and the model
checkpoint from ./checkpoints/sketchy/sketchy.pth and use it to generate video
predictions. Make sure the path to the dataset is defined correctly in hparams.json.
For more options, see python eval/eval_gen_metrics.py --help.
For a similar evaluation of DLPv3 in the single-image setting, see eval_dlp_im_metric() in
/eval/eval_gen_metrics.py.
For FVD, we use: https://github.com/JunyaoHu/common_metrics_on_video_quality
Please download the network checkpoint from that repository before evaluating FVD and place it at /eval/fvd/fvd/styleganv/i3d_torchscript.pt.
To evaluate a pre-trained LPWM model on video generation, we provide a script:
python eval/eval_fvd.py --help
For example, to evaluate a model saved in /checkpoints/sketchy/sketchy.pth with 6 conditional frames and a generation horizon of 50 frames:
```shell
python eval/eval_fvd.py -d sketchy -p ./checkpoints/sketchy --checkpoint ./checkpoints/sketchy/sketchy.pth --sample -b 4 -c 6 --horizon 50 --prefix "" --n_videos_per_clip 1
```
where b is the batch size and n_videos_per_clip specifies how many videos to generate per video in the dataset (for the bair64 benchmark, this number is set to 100).
The script will load the config file hparams.json from ./checkpoints/sketchy/hparams.json and the model
checkpoint from ./checkpoints/sketchy/sketchy.pth and use it to generate video
predictions. Make sure the path to the dataset is defined correctly in hparams.json.
For more options, see python eval/eval_fvd.py --help.
To generate video predictions using a pre-trained model use generate_lpwm_video_prediction.py as follows:
```shell
python generate_lpwm_video_prediction.py -d sketchy -p ./checkpoints/sketchy --checkpoint ./checkpoints/sketchy/sketchy.pth --sample -n 4 -c 6 --horizon 50 --prefix ""
```
The script will load the config file hparams.json from ./checkpoints/sketchy/hparams.json and the model checkpoint from ./checkpoints/sketchy/sketchy.pth, and use it to generate n video predictions based on 6 conditional input frames and a final video length of 50 frames. In the example above, four (n=4) videos will be generated and saved within a videos directory (created if it doesn't exist) under the checkpoint directory. Make sure the path to the dataset is defined correctly in hparams.json.
For more options, see python generate_lpwm_video_prediction.py --help.
For your convenience, we provide more documentation in /docs and more examples of using the models in /notebooks.
| File | Content |
|---|---|
| docs/installation.md | Manual instructions to install packages with conda |
| docs/hyperparameters.md | Explanations of the models' various hyperparameters and recommended values |
| docs/example_usage.py | Overview of the models' functionality: forward output, loss calculation and sampling |
| notebooks/dlpv3_lpwm_walkthrough_tutorial.ipynb | Tutorial and walkthrough of DLPv3 and LPWM, where we train and evaluate a DLPv3 model on the Shapes dataset |
| notebooks/lpwm_gen_sketchy.ipynb / notebooks/lpwm_gen_langtable.ipynb | Generating videos with pre-trained LPWMs and creating figures |
Please see our tutorial if you are new to Deep Latent Particles, where you can train the model online:
notebooks/dlpv3_lpwm_walkthrough_tutorial.ipynb
| File name | Content |
|---|---|
| /checkpoints | directory for pre-trained checkpoints |
| /assets | directory containing sample images |
| /datasets | directory containing data loading classes for the various datasets |
| /configs | directory containing config files for the various datasets |
| /docs | various documentation files |
| /notebooks | various Jupyter Notebook examples of DLPv3 and LPWM |
| /eval/eval_model.py | evaluation functions such as evaluating the ELBO |
| /eval/eval_gen_metrics.py | evaluation functions for image metrics (LPIPS, PSNR, SSIM) |
| /modules/modules.py | basic neural network blocks used to implement DLPv3 |
| /utils/loss_functions.py | loss functions used to optimize the model, such as Chamfer-KL and perceptual (VGG/LPIPS) loss |
| /utils/util_func.py | utility functions such as logging and plotting functions, Spatial Transformer Network (STN) |
| models.py | implementation of DLPv3 and LPWM |
| train_dlp.py / train_lpwm.py | training functions of DLP/LPWM for single-GPU machines |
| train_dlp_accelerate.py / train_lpwm_accelerate.py | training functions of DLP/LPWM for multi-GPU machines |
| environment.yml | Anaconda environment file to install the required dependencies |
| requirements.txt | requirements file for pip |
| accel_conf.yml | configuration file for accelerate to run training on multiple GPUs |