Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/funtracks/data_model/solution_tracks.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def export_tracks(
# Import here to avoid circular imports
from funtracks.import_export.csv._export import export_to_csv

export_to_csv(self, outfile, node_ids)
export_to_csv(self, outfile=outfile, node_ids=node_ids)

def get_track_neighbors(
self, track_id: int, time: int
Expand Down
174 changes: 132 additions & 42 deletions src/funtracks/import_export/csv/_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast

import numpy as np
import pandas as pd
import tifffile
from skimage.util import map_array

from .._utils import filter_graph_with_ancestors

Expand All @@ -16,23 +19,32 @@
def export_to_csv(
tracks: SolutionTracks,
outfile: Path | str,
color_dict: dict[int, np.ndarray] | None = None,
node_ids: set[int] | None = None,
use_display_names: bool = False,
export_seg: bool = False,
seg_path: Path | str | None = None,
) -> None:
"""Export tracks to a CSV file.
TODO: export_all = False for backward compatibility - display names option shouldn't
change which columns are exported, just using which names

Exports tracking data to CSV format with columns for node ID, parent ID,
and all registered features.
and all registered features. Optionally also exports the segmentation, relabeled by
tracklet ID, as tif. If a color dictionary is provided, it will also export the
tracklet colors.

Args:
tracks: SolutionTracks object containing the tracking data to export
outfile: Path to output CSV file
color_dict: dict[int, np.ndarray], optional. If provided, will be used to save the
hex colors.
node_ids: Optional set of node IDs to include. If provided, only these
nodes and their ancestors will be included in the output.
use_display_names: If True, use feature display names as column headers.
If False (default), use raw feature keys for backward compatibility.
export_seg (bool): whether to export the segmentation, relabeled by tracklet ID
seg_path (Path | str, optional): path to save segmentation file to, if requested.

Example:
>>> from funtracks.import_export import export_to_csv
Expand All @@ -41,6 +53,8 @@ def export_to_csv(
>>> export_to_csv(tracks, "output.csv", use_display_names=True)
>>> # Export only specific nodes
>>> export_to_csv(tracks, "filtered.csv", node_ids={1, 2, 3})
>>> # Export with segmentation
>>> export_to_csv(tracks, "filtered.csv", export_seg=True, seg_path="seg.tif")
"""

def convert_numpy_to_python(value):
Expand All @@ -51,43 +65,62 @@ def convert_numpy_to_python(value):
return int(value)
return value

header: list[str] = []
column_map: dict[str, str | list[str]] = {}

# Build header - use old hardcoded format for backward compatibility
if use_display_names:
header = ["ID", "Parent ID"]
header.extend(["ID", "Parent ID"])
column_map["id"] = "ID"
column_map["parent_id"] = "Parent ID"
else:
# Backward compatibility: use old column names
# Old format: t, [z], y, x, id, parent_id, track_id
header = ["t"]
if tracks.ndim == 4:
header.extend(["z", "y", "x"])
else: # ndim == 3
header.extend(["y", "x"])
# time
header.append("t")
column_map["time"] = "t"

# spatial coordinates
coords = ["z", "y", "x"] if tracks.ndim == 4 else ["y", "x"]

header.extend(coords)
column_map["coords"] = coords

# identifiers
header.extend(["id", "parent_id", "track_id"])
column_map["id"] = "id"
column_map["parent_id"] = "parent_id"
column_map["track_id"] = "track_id"

# For display names mode, build dynamic header from features
feature_names = []
if use_display_names:
for feature_name, feature_dict in tracks.features.items():
feature_names.append(feature_name)
num_values = feature_dict.get("num_values", 1)

if num_values > 1:
# Multi-value feature: use value_names if available
value_names = feature_dict.get("value_names")
if value_names is not None:
header.extend(value_names)
names = list(value_names)
else:
# Fall back to display_name or feature_name with index suffix
base_name = feature_dict.get("display_name", feature_name)
if len(base_name) == num_values:
if (
isinstance(base_name, (list, tuple))
and len(base_name) == num_values
):
# use list elements
header.extend(list(base_name))
names = list(base_name)
else:
# use a suffix
header.extend([f"{base_name}_{i}" for i in range(num_values)])
names = [f"{base_name}_{i}" for i in range(num_values)]
header.extend(names)
else:
# Single-value feature: use display_name or feature_name
col_name = feature_dict.get("display_name", feature_name)
header.append(col_name)
names = feature_dict.get("display_name", feature_name)
header.extend([names])

column_map[feature_name] = names

# Determine which nodes to export
if node_ids is None:
Expand All @@ -96,29 +129,86 @@ def convert_numpy_to_python(value):
node_to_keep = filter_graph_with_ancestors(tracks.graph, node_ids)

# Write CSV file
with open(outfile, "w") as f:
f.write(",".join(header))
for node_id in node_to_keep:
parents = list(tracks.graph.predecessors(node_id))
parent_id = "" if len(parents) == 0 else parents[0]

if use_display_names:
# Dynamic feature export
features: list[Any] = []
for feature_name in feature_names:
feature_value = tracks.get_node_attr(node_id, feature_name)
if isinstance(feature_value, list | tuple):
features.extend(feature_value)
else:
features.append(feature_value)
row = [node_id, parent_id, *features]
else:
# Backward compatibility: hardcoded format matching old behavior
time = tracks.get_time(node_id)
position = tracks.get_position(node_id)
track_id = tracks.get_track_id(node_id)
row = [time, *position, node_id, parent_id, track_id]

row = [convert_numpy_to_python(value) for value in row]
f.write("\n")
f.write(",".join(map(str, row)))
rows: list[dict[str, Any]] = []

for node_id in node_to_keep:
parents = list(tracks.graph.predecessors(node_id))
parent_id = "" if len(parents) == 0 else parents[0]

row: dict[str, Any]

row = {}
row[cast(str, column_map["id"])] = node_id
row[cast(str, column_map["parent_id"])] = parent_id

if use_display_names:
for feature_name in feature_names:
value = tracks.get_node_attr(node_id, feature_name)
cols = column_map[feature_name]
if isinstance(cols, list):
assert isinstance(value, (list, tuple))
for col, v in zip(cols, value, strict=True):
row[col] = convert_numpy_to_python(v)
else:
row[cols] = convert_numpy_to_python(value)

else:
row[cast(str, column_map["time"])] = convert_numpy_to_python(
tracks.get_time(node_id)
)

pos = tracks.get_position(node_id)
for name, value in zip(column_map["coords"], pos, strict=True):
row[name] = convert_numpy_to_python(value)

row[cast(str, column_map["track_id"])] = tracks.get_track_id(node_id)

rows.append(row)

df = pd.DataFrame(rows)
df = df[header]

# Also add a column with the track ID color
if color_dict is not None:

def rgb_to_hex(rgb):
"""Convert [R, G, B] to #RRGGBB."""
r, g, b = [int(round(c * 255)) for c in rgb[:3]] # scale and convert to int
return f"#{r:02x}{g:02x}{b:02x}"

track_id_to_hex = {}

for track_id, nodes in tracks.track_id_to_node.items():
if not nodes:
continue
first_node = nodes[0]
rgb = color_dict[first_node]
track_id_to_hex[track_id] = rgb_to_hex(rgb)

df_colors = pd.DataFrame(
list(track_id_to_hex.items()), # convert dict to list of (track_id, hex)
columns=[column_map["track_id"], "Tracklet ID Color"],
)

df = pd.merge(df, df_colors, how="left", on=[column_map["track_id"]])

df.to_csv(outfile, index=False)

if export_seg:
# Determine maximum value in the column to assign bit depth
max_val = int(df[column_map["track_id"]].max())

# Pick dtype based on max_val
if max_val <= np.iinfo(np.uint8).max:
dtype = np.uint8
elif max_val <= np.iinfo(np.uint16).max:
dtype = np.uint16
elif max_val <= np.iinfo(np.uint32).max:
dtype = np.uint32
else:
dtype = np.uint64 # large values

input_vals = np.array(df[column_map["id"]])
output_vals = np.array(df[column_map["track_id"]], dtype=dtype)
relabeled_seg = map_array(tracks.segmentation, input_vals, output_vals)
tifffile.imwrite(seg_path, relabeled_seg, compression="deflate")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add to the docstring (and to the user documentation) that the relabeled segmentation is written out as a TIFF file via `tifffile.imwrite`, and document the `seg_path` requirement when `export_seg=True`.

35 changes: 33 additions & 2 deletions tests/import_export/test_csv_export.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import numpy as np
import pytest
import tifffile

from funtracks.import_export import export_to_csv

Expand All @@ -25,12 +27,41 @@ def test_export_solution_to_csv(get_tracks, tmp_path, ndim, expected_header):

# Check first data line (node 1: t=0, pos=[50, 50] or [50, 50, 50], track_id=1)
if ndim == 3:
expected_line1 = ["0", "50", "50", "1", "", "1"]
expected_line1 = ["0", "50.0", "50.0", "1", "", "1"]
else:
expected_line1 = ["0", "50", "50", "50", "1", "", "1"]
expected_line1 = ["0", "50.0", "50.0", "50.0", "1", "", "1"]
assert lines[1].strip().split(",") == expected_line1


@pytest.mark.parametrize(
("ndim", "expected_header"),
[
(3, ["t", "y", "x", "id", "parent_id", "track_id"]),
(4, ["t", "z", "y", "x", "id", "parent_id", "track_id"]),
],
ids=["2d", "3d"],
)
def test_export_solution_to_csv_with_seg(get_tracks, tmp_path, ndim, expected_header):
"""Test exporting tracks to CSV + relabeled segmentation."""
tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True)
temp_file = tmp_path / "test_export.csv"
seg_file = tmp_path / "test_export.tif"
export_to_csv(tracks, temp_file, export_seg=True, seg_path=seg_file)

with open(temp_file) as f:
lines = f.readlines()

assert len(lines) == tracks.graph.number_of_nodes() + 1 # add header
assert lines[0].strip().split(",") == expected_header

# check the segmentation
seg = tifffile.imread(seg_file)
if ndim == 3:
np.all(seg[2, 0:4, 0:4] == 3) # node id was 4, should be relabeled to track id 3
if ndim == 4:
np.all(seg[2, 0:4, 0:4, 0:4] == 3)


def test_export_with_display_names(get_tracks, tmp_path):
"""Test exporting with display names."""
tracks = get_tracks(ndim=3, with_seg=False, is_solution=True)
Expand Down