Merged
42 commits
77e8736
Refactor pose visualization in vis_in_the_wild.py
xiu-cs Feb 5, 2026
5001c3c
Update vis_in_the_wild.sh for model path and GPU configuration
xiu-cs Feb 5, 2026
5091074
Refactor argument parsing in arguments.py
xiu-cs Feb 5, 2026
fb14cf5
Update FMPose3D_test.sh to reflect new model paths
xiu-cs Feb 5, 2026
500e097
Update vis_in_the_wild.sh to standardize model weights path
xiu-cs Feb 5, 2026
82244ce
Update FMPose3D_train.sh for model path consistency
xiu-cs Feb 5, 2026
dbd2afa
Refactor aggregation method in aggregation_methods.py
xiu-cs Feb 5, 2026
d6dcf06
Add model_path argument in arguments.py and remove saved_model_path
xiu-cs Feb 5, 2026
db005f3
Refactor file backup process in FMPose3D_main.py
xiu-cs Feb 5, 2026
0a8487a
Remove weight_softmax_tau variable from FMPose3D_test.sh for consiste…
xiu-cs Feb 5, 2026
6608049
fix the path error
xiu-cs Feb 5, 2026
af6ec60
Update model path variable in vis_in_the_wild.py to align with recent…
xiu-cs Feb 8, 2026
b634a75
correct the color of joints
xiu-cs Feb 8, 2026
ea6e291
Add demo GIF for visual representation
xiu-cs Feb 8, 2026
c6c148c
Update demo image in README.md from JPG to GIF for enhanced visual re…
xiu-cs Feb 8, 2026
d76b239
update the model structure
xiu-cs Feb 9, 2026
1c3ca70
Revise README for clarity and updates
MMathisLab Feb 6, 2026
e02edd5
Update torch.load weigths_only=True
deruyter92 Feb 9, 2026
1dbea87
fix README broken link and typo
deruyter92 Feb 9, 2026
82149c9
Replace torch Variable with torch tensor
deruyter92 Feb 9, 2026
a184cb0
update cuda fallback
deruyter92 Feb 9, 2026
1739378
update gitignore
deruyter92 Feb 9, 2026
077eaa6
rename get_varialbe -> get_variable everywhere
deruyter92 Feb 9, 2026
81a22c6
Fix sys.path imports -> proper module references
deruyter92 Feb 9, 2026
f3d6ba8
Apply suggestion from @Copilot
xiu-cs Feb 9, 2026
67268c6
Apply suggestion from @Copilot
xiu-cs Feb 9, 2026
d74a76d
Apply suggestion from @Copilot
xiu-cs Feb 9, 2026
f8ab0a6
Add config dataclasses (in parallel to arguments.py)
deruyter92 Feb 9, 2026
db775ea
add tests for config.py
deruyter92 Feb 9, 2026
4611178
Apply suggestion from @Copilot
xiu-cs Feb 9, 2026
6b5d354
Apply suggestion from @Copilot
xiu-cs Feb 9, 2026
a4df14b
Merge pull request #12 from deruyter92/jaap/minor_refactors
xiu-cs Feb 9, 2026
b73d9ce
Feat: add extendable model registry
deruyter92 Feb 9, 2026
50125e5
change demo script add example for human pose model
deruyter92 Feb 9, 2026
5677b53
Merge branch 'ti_video_demo' into jaap/add_config_and_registry
xiu-cs Feb 9, 2026
4c8b201
update config: replace model_path with model_type from the registry
deruyter92 Feb 9, 2026
acb4feb
Merge branch 'jaap/add_config_and_registry' of github.com:deruyter92/…
deruyter92 Feb 9, 2026
3863ddf
Update config: extendable configs and changed name -> PipelineConfig
deruyter92 Feb 10, 2026
c4bf891
Add HRNet model api
deruyter92 Feb 10, 2026
0827a1b
Add high-level inference API for FMPose3D (fmpose3d/fmpose3d.py)
deruyter92 Feb 10, 2026
bad89d7
Merge branch 'main' into feat/add_api
xiu-cs Feb 10, 2026
1023b4e
Add documentation header for FMPose3D in HRNet files
xiu-cs Feb 10, 2026
5 changes: 5 additions & 0 deletions .gitignore
@@ -45,3 +45,8 @@ htmlcov/
*.pkl
*.h5
*.ckpt

# Excluded directories
pre_trained_models/
demo/predictions/
demo/images/
8 changes: 4 additions & 4 deletions README.md
@@ -2,11 +2,11 @@

![Version](https://img.shields.io/badge/python_version-3.10-purple)
[![PyPI version](https://badge.fury.io/py/fmpose3d.svg?icon=si%3Apython)](https://badge.fury.io/py/fmpose3d)
[![License: LApache 2.0](https://img.shields.io/badge/License-Apache2.0-blue.svg)](https://www.gnu.org/licenses/apach2.0)
[![License: Apache 2.0](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://www.apache.org/licenses/LICENSE-2.0)

This is the official implementation of the approach described in the preprint:

[**FMPose3D: monocular 3D pose estimation via flow matching**](http://arxiv.org/abs/2602.05755)
[**FMPose3D: monocular 3D pose estimation via flow matching**](https://arxiv.org/abs/2602.05755)
Ti Wang, Xiaohang Yu, Mackenzie Weygandt Mathis

<!-- <p align="center"><img src="./images/Frame 4.jpg" width="50%" alt="" /></p> -->
@@ -51,7 +51,7 @@ sh vis_in_the_wild.sh
```
The predictions will be saved to folder `demo/predictions`.

<p align="center"><img src="./images/demo.jpg" width="95%" alt="" /></p>
<p align="center"><img src="./images/demo.gif" width="95%" alt="" /></p>

## Training and Inference

@@ -79,7 +79,7 @@ The training logs, checkpoints, and related files of each training time will be

For training on Human3.6M:
```bash
sh /scripts/FMPose3D_train.sh
sh ./scripts/FMPose3D_train.sh
```

### Inference
14 changes: 7 additions & 7 deletions animals/demo/vis_animals.py
@@ -8,7 +8,6 @@
"""

# SuperAnimal Demo: https://github.com/DeepLabCut/DeepLabCut/blob/main/examples/COLAB/COLAB_YOURDATA_SuperAnimal.ipynb
import sys
import os
import numpy as np
import glob
@@ -25,8 +24,6 @@
from fmpose3d.animals.common.arguments import opts as parse_args
from fmpose3d.common.camera import normalize_screen_coordinates, camera_to_world

sys.path.append(os.getcwd())

args = parse_args().parse()
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

@@ -334,13 +331,15 @@ def get_pose3D(path, output_dir, type='image'):
print(f"args.n_joints: {args.n_joints}, args.out_joints: {args.out_joints}")

## Reload model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = {}
model['CFM'] = CFM(args).cuda()
model['CFM'] = CFM(args).to(device)

model_dict = model['CFM'].state_dict()
model_path = args.saved_model_path
print(f"Loading model from: {model_path}")
pre_dict = torch.load(model_path)
pre_dict = torch.load(model_path, map_location=device, weights_only=True)
for name, key in model_dict.items():
model_dict[name] = pre_dict[name]
model['CFM'].load_state_dict(model_dict)
@@ -400,7 +399,8 @@ def get_3D_pose_from_image(args, keypoints, i, img, model, output_dir):
input_2D = np.expand_dims(input_2D, axis=0) # (1, J, 2)

# Convert to tensor format matching visualize_animal_poses.py
input_2D = torch.from_numpy(input_2D.astype('float32')).cuda() # (1, J, 2)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
input_2D = torch.from_numpy(input_2D.astype('float32')).to(device) # (1, J, 2)
input_2D = input_2D.unsqueeze(0) # (1, 1, J, 2)

# Euler sampler for CFM
@@ -418,7 +418,7 @@ def euler_sample(c_2d, y_local, steps, model_3d):

# Single inference without flip augmentation
# Create 3D random noise with shape (1, 1, J, 3)
y = torch.randn(input_2D.size(0), input_2D.size(1), input_2D.size(2), 3).cuda()
y = torch.randn(input_2D.size(0), input_2D.size(1), input_2D.size(2), 3, device=device)
output_3D = euler_sample(input_2D, y, steps=args.sample_steps, model_3d=model)

output_3D = output_3D[0:, args.pad].unsqueeze(1)
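This vis_animals.py hunk (and the main_animal3d.py one below) repeat the same two-step loading pattern introduced across the PR: derive a device with a CPU fallback, then load checkpoints with `map_location` and `weights_only=True`. A minimal, self-contained sketch of that pattern; the stand-in `nn.Linear` model and the temp-file path are illustrative, not from the repo:

```python
import os
import tempfile

import torch
import torch.nn as nn

# Stand-in model; the real scripts build CFM(args) instead.
model = nn.Linear(17 * 2, 17 * 3)

# CPU fallback added throughout this PR: no hard dependency on CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Round-trip a state dict the way the updated scripts load weights:
# map_location lets CUDA-saved checkpoints load on CPU-only machines, and
# weights_only=True (recent PyTorch) restricts torch.load to tensors/primitives.
ckpt_path = os.path.join(tempfile.gettempdir(), "example_weights.pth")
torch.save(model.state_dict(), ckpt_path)
state = torch.load(ckpt_path, map_location=device, weights_only=True)
model.load_state_dict(state)
```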
8 changes: 5 additions & 3 deletions animals/scripts/main_animal3d.py
@@ -75,7 +75,7 @@ def step(split, args, actions, dataLoader, model, optimizer=None, epoch=None, st
# gt_3D shape: torch.Size([B, J, 4]) (x,y,z + homogeneous coordinate)
gt_3D = gt_3D[:,:,:3] # only use x,y,z for 3D ground truth

# [input_2D, gt_3D, batch_cam, vis_3D] = get_varialbe(split, [input_2D, gt_3D, batch_cam, vis_3D])
# [input_2D, gt_3D, batch_cam, vis_3D] = get_variable(split, [input_2D, gt_3D, batch_cam, vis_3D])

# unsqueeze frame dimension
input_2D = input_2D.unsqueeze(1) # (B,F,J,C)
@@ -264,15 +264,17 @@ def get_parameter_number(net):
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size,
shuffle=False, num_workers=int(args.workers), pin_memory=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = {}
model['CFM'] = CFM(args).cuda()
model['CFM'] = CFM(args).to(device)

if args.reload:
model_dict = model['CFM'].state_dict()
# Prefer explicit saved_model_path; otherwise fallback to previous_dir glob
model_path = args.saved_model_path
print(model_path)
pre_dict = torch.load(model_path)
pre_dict = torch.load(model_path, weights_only=True, map_location=device)
for name, key in model_dict.items():
model_dict[name] = pre_dict[name]
model['CFM'].load_state_dict(model_dict)
73 changes: 40 additions & 33 deletions demo/vis_in_the_wild.py
@@ -7,7 +7,6 @@
Licensed under Apache 2.0
"""

import sys
import cv2
import os
import numpy as np
@@ -16,8 +15,6 @@
from tqdm import tqdm
import copy

sys.path.append(os.getcwd())

# Auto-download checkpoint files if missing
from fmpose3d.lib.checkpoint.download_checkpoints import ensure_checkpoints
ensure_checkpoints()
@@ -28,17 +25,10 @@

args = parse_args().parse()
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
if getattr(args, 'model_path', ''):
import importlib.util
import pathlib
model_abspath = os.path.abspath(args.model_path)
module_name = pathlib.Path(model_abspath).stem
spec = importlib.util.spec_from_file_location(module_name, model_abspath)
module = importlib.util.module_from_spec(spec)
assert spec.loader is not None
spec.loader.exec_module(module)
CFM = getattr(module, 'Model')


from fmpose3d.models import get_model
CFM = get_model(args.model_type)

from fmpose3d.common.camera import *

import matplotlib
@@ -50,15 +40,27 @@
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42

def show2Dpose(kps, img):
connections = [[0, 1], [1, 2], [2, 3], [0, 4], [4, 5],
[5, 6], [0, 7], [7, 8], [8, 9], [9, 10],
[8, 11], [11, 12], [12, 13], [8, 14], [14, 15], [15, 16]]
# Shared skeleton definition so 2D/3D segment colors match
SKELETON_CONNECTIONS = [
[0, 1], [1, 2], [2, 3], [0, 4], [4, 5],
[5, 6], [0, 7], [7, 8], [8, 9], [9, 10],
[8, 11], [11, 12], [12, 13], [8, 14], [14, 15], [15, 16]
]
# LR mask for skeleton segments: True -> left color, False -> right color
SKELETON_LR = np.array(
[0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
dtype=bool,
)

LR = np.array([0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], dtype=bool)
def show2Dpose(kps, img):
connections = SKELETON_CONNECTIONS
LR = SKELETON_LR

lcolor = (255, 0, 0)
rcolor = (0, 0, 255)
# lcolor = (240, 176, 0)
# rcolor = (240, 176, 0)

thickness = 3

for j,c in enumerate(connections):
@@ -67,8 +69,8 @@ def show2Dpose(kps, img):
start = list(start)
end = list(end)
cv2.line(img, (start[0], start[1]), (end[0], end[1]), lcolor if LR[j] else rcolor, thickness)
cv2.circle(img, (start[0], start[1]), thickness=-1, color=(0, 255, 0), radius=3)
cv2.circle(img, (end[0], end[1]), thickness=-1, color=(0, 255, 0), radius=3)
# cv2.circle(img, (start[0], start[1]), thickness=-1, color=(0, 255, 0), radius=3)
# cv2.circle(img, (end[0], end[1]), thickness=-1, color=(0, 255, 0), radius=3)

return img

@@ -77,11 +79,13 @@ def show3Dpose(vals, ax):

lcolor=(0,0,1)
rcolor=(1,0,0)

I = np.array( [0, 0, 1, 4, 2, 5, 0, 7, 8, 8, 14, 15, 11, 12, 8, 9])
J = np.array( [1, 4, 2, 5, 3, 6, 7, 8, 14, 11, 15, 16, 12, 13, 9, 10])

LR = np.array([0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0], dtype=bool)
# lcolor=(0/255, 176/255, 240/255)
# rcolor=(0/255, 176/255, 240/255)


I = np.array([c[0] for c in SKELETON_CONNECTIONS])
J = np.array([c[1] for c in SKELETON_CONNECTIONS])
LR = SKELETON_LR

for i in np.arange( len(I) ):
x, y, z = [np.array( [vals[I[i], j], vals[J[i], j]] ) for j in range(3)]
@@ -199,7 +203,8 @@ def get_3D_pose_from_image(args, keypoints, i, img, model, output_dir):

input_2D = input_2D[np.newaxis, :, :, :, :]

input_2D = torch.from_numpy(input_2D.astype('float32')).cuda()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
input_2D = torch.from_numpy(input_2D.astype('float32')).to(device)

N = input_2D.size(0)

@@ -215,10 +220,10 @@ def euler_sample(c_2d, y_local, steps, model_3d):

## estimation

y = torch.randn(input_2D.size(0), input_2D.size(2), input_2D.size(3), 3).cuda()
y = torch.randn(input_2D.size(0), input_2D.size(2), input_2D.size(3), 3, device=device)
output_3D_non_flip = euler_sample(input_2D[:, 0], y, steps=args.sample_steps, model_3d=model)

y_flip = torch.randn(input_2D.size(0), input_2D.size(2), input_2D.size(3), 3).cuda()
y_flip = torch.randn(input_2D.size(0), input_2D.size(2), input_2D.size(3), 3, device=device)
output_3D_flip = euler_sample(input_2D[:, 1], y_flip, steps=args.sample_steps, model_3d=model)

output_3D_flip[:, :, :, 0] *= -1
@@ -266,14 +271,16 @@ def get_pose3D(path, output_dir, type='image'):
# args.type = type

## Reload
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = {}
model['CFM'] = CFM(args).cuda()
model['CFM'] = CFM(args).to(device)

# if args.reload:
model_dict = model['CFM'].state_dict()
model_path = args.saved_model_path
model_path = args.model_weights_path
print(model_path)
pre_dict = torch.load(model_path)
pre_dict = torch.load(model_path, map_location=device, weights_only=True)
for name, key in model_dict.items():
model_dict[name] = pre_dict[name]
model['CFM'].load_state_dict(model_dict)
@@ -336,7 +343,7 @@ def get_pose3D(path, output_dir, type='image'):
## save
output_dir_pose = output_dir +'pose/'
os.makedirs(output_dir_pose, exist_ok=True)
plt.savefig(output_dir_pose + str(('%04d'% i)) + '_pose.jpg', dpi=200, bbox_inches = 'tight')
plt.savefig(output_dir_pose + str(('%04d'% i)) + '_pose.png', dpi=200, bbox_inches = 'tight')


if __name__ == "__main__":
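The hunks above call `euler_sample` to integrate the flow-matching model from Gaussian noise to a 3D pose, but the loop body is collapsed in this view. A generic Euler-integration sketch of what such a sampler does, with a toy velocity field standing in for the trained CFM network; the time/conditioning handling here is an assumption, not the repo's exact code:

```python
import torch

def euler_sample(condition, y, steps, velocity_fn):
    """Generic Euler integrator for a flow-matching sampler (illustrative only)."""
    dt = 1.0 / steps
    t = torch.zeros(y.shape[0], device=y.device)
    for _ in range(steps):
        v = velocity_fn(y, t, condition)  # predicted velocity at the current point/time
        y = y + dt * v                    # one Euler step along the flow
        t = t + dt
    return y

def toy_velocity(y, t, condition):
    # Stand-in for the trained network: pulls samples toward the origin.
    return -y

noise = torch.randn(2, 17, 3)  # (batch, joints, xyz) starting noise, as in the scripts
cond = torch.zeros(2, 17, 2)   # placeholder for normalized 2D keypoints
sample = euler_sample(cond, noise, steps=3, velocity_fn=toy_velocity)
print(sample.shape)  # torch.Size([2, 17, 3])
```

The script runs this twice, once on the horizontally flipped 2D input, then mirrors the flipped output back (`output_3D_flip[:, :, :, 0] *= -1`) and combines the two results as a test-time augmentation.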
17 changes: 9 additions & 8 deletions demo/vis_in_the_wild.sh
@@ -1,21 +1,22 @@
#Test
layers=5
gpu_id=1
gpu_id=0
sample_steps=3
batch_size=1
sh_file='vis_in_the_wild.sh'

model_path='../pre_trained_models/fmpose_detected2d/model_GAMLP.py'
saved_model_path='../pre_trained_models/fmpose_detected2d/FMpose_36_4972_best.pth'
model_type='fmpose3d'
model_weights_path='../pre_trained_models/fmpose3d_h36m/FMpose3D_pretrained_weights.pth'

# path='./images/image_00068.jpg' # single image
input_images_folder='./images/' # folder containing multiple images
target_path='./images/' # folder containing multiple images
# target_path='./images/xx.png' # single image
# target_path='./videos/xxx.mp4' # video path

python3 vis_in_the_wild.py \
--type 'image' \
--path ${input_images_folder} \
--saved_model_path "${saved_model_path}" \
--model_path "${model_path}" \
--path ${target_path} \
--model_weights_path "${model_weights_path}" \
--model_type "${model_type}" \
--sample_steps ${sample_steps} \
--batch_size ${batch_size} \
--layers ${layers} \
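The demo script now selects the architecture by name (`model_type='fmpose3d'`) instead of pointing at a model source file, matching the `from fmpose3d.models import get_model` call added in vis_in_the_wild.py above. A minimal sketch of an extendable registry of this kind; the decorator name and error handling are assumptions, not the repo's exact implementation:

```python
# Minimal registry sketch; fmpose3d.models.get_model may differ in its details.
_MODEL_REGISTRY = {}

def register_model(name):
    """Class decorator that makes a model class constructible by name."""
    def wrap(cls):
        _MODEL_REGISTRY[name] = cls
        return cls
    return wrap

def get_model(name):
    try:
        return _MODEL_REGISTRY[name]
    except KeyError as err:
        raise KeyError(
            f"Unknown model_type '{name}'. Registered: {sorted(_MODEL_REGISTRY)}"
        ) from err

@register_model("fmpose3d")
class CFM:
    """Stand-in for the real flow-matching model class."""
    def __init__(self, args=None):
        self.args = args

ModelCls = get_model("fmpose3d")  # mirrors CFM = get_model(args.model_type) in the demo
model = ModelCls(args=None)
```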
32 changes: 32 additions & 0 deletions fmpose3d/__init__.py
@@ -18,17 +18,49 @@
aggregation_RPEA_joint_level,
)

# Configuration dataclasses
from .common.config import (
FMPose3DConfig,
HRNetConfig,
InferenceConfig,
ModelConfig,
PipelineConfig,
)

# High-level inference API
from .fmpose3d import (
FMPose3DInference,
HRNetEstimator,
Pose2DResult,
Pose3DResult,
Source,
)

# Import 2D pose detection utilities
from .lib.hrnet.gen_kpts import gen_video_kpts
from .lib.hrnet.hrnet import HRNetPose2d
from .lib.preprocess import h36m_coco_format, revise_kpts

# Make commonly used classes/functions available at package level
__all__ = [
# Inference API
"FMPose3DInference",
"HRNetEstimator",
"Pose2DResult",
"Pose3DResult",
"Source",
# Configuration
"FMPose3DConfig",
"HRNetConfig",
"InferenceConfig",
"ModelConfig",
"PipelineConfig",
# Aggregation methods
"average_aggregation",
"aggregation_select_single_best_hypothesis_by_2D_error",
"aggregation_RPEA_joint_level",
# 2D pose detection
"HRNetPose2d",
"gen_video_kpts",
"h36m_coco_format",
"revise_kpts",
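With these re-exports, the configuration dataclasses and the high-level inference API are importable from the package root. The import below is grounded in the `__all__` list above; constructor signatures are not shown in this diff, so the usage lines are left as hypothetical comments:

```python
# Top-level imports enabled by the new __init__.py exports.
from fmpose3d import FMPose3DInference, HRNetPose2d, PipelineConfig

# Hypothetical usage (signatures not shown in this diff):
# config = PipelineConfig(...)
# pipeline = FMPose3DInference(config)
# poses_3d = pipeline.run("path/to/video.mp4")
```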
4 changes: 0 additions & 4 deletions fmpose3d/aggregation_methods.py
@@ -166,17 +166,13 @@ def aggregation_RPEA_joint_level(
dist[:, :, 0] = 0.0

# Convert 2D losses to weights using softmax over top-k hypotheses per joint
tau = float(getattr(args, "weight_softmax_tau", 1.0))
H = dist.size(1)
k = int(getattr(args, "topk", None))
# print("k:", k)
# k = int(H//2)+1
k = max(1, min(k, H))

# top-k smallest distances along hypothesis dim
topk_vals, topk_idx = torch.topk(dist, k=k, dim=1, largest=False) # (B,k,J)

# Weight calculation method ; weight_method = 'exp'
temp = args.exp_temp
max_safe_val = temp * 20
topk_vals_clipped = torch.clamp(topk_vals, max=max_safe_val)
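The hunk above removes dead code around the per-joint hypothesis weighting: for each joint, take the `k` hypotheses with the smallest 2D error, clamp the values, and convert them into exponential weights. A self-contained sketch of that weighting step; the exact normalization in `aggregation_RPEA_joint_level` may differ from this softmax form:

```python
import torch

def topk_exp_weights(dist, k, temp):
    """Per-joint exponential weighting over the k best hypotheses (illustrative).

    dist: (B, H, J) per-hypothesis, per-joint 2D error.
    Returns weights of shape (B, k, J) and the indices of the selected hypotheses.
    """
    k = max(1, min(k, dist.size(1)))
    topk_vals, topk_idx = torch.topk(dist, k=k, dim=1, largest=False)  # k smallest errors
    topk_vals = torch.clamp(topk_vals, max=temp * 20)                  # clipping, as in the diff
    weights = torch.softmax(-topk_vals / temp, dim=1)                  # smaller error -> larger weight
    return weights, topk_idx

weights, idx = topk_exp_weights(torch.rand(2, 8, 17), k=3, temp=0.5)
print(weights.shape)  # torch.Size([2, 3, 17])
```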
7 changes: 2 additions & 5 deletions fmpose3d/animals/common/arber_dataset.py
@@ -12,7 +12,6 @@
import glob
import os
import random
import sys

import cv2
import matplotlib.pyplot as plt
@@ -23,10 +22,8 @@
from torch.utils.data import Dataset
from tqdm import tqdm

sys.path.append(os.path.dirname(sys.path[0]))

from common.camera import normalize_screen_coordinates
from common.lifter3d import load_camera_params, load_h5_keypoints
from fmpose3d.common.camera import normalize_screen_coordinates
from fmpose3d.animals.common.lifter3d import load_camera_params, load_h5_keypoints


class ArberDataset(Dataset):