diff --git a/src/funtracks/data_model/__init__.py b/src/funtracks/data_model/__init__.py index f02d7294..ef0f4187 100644 --- a/src/funtracks/data_model/__init__.py +++ b/src/funtracks/data_model/__init__.py @@ -1,4 +1,2 @@ from .tracks import Tracks # noqa from .solution_tracks import SolutionTracks # noqa -from .tracks_controller import TracksController # noqa -from .graph_attributes import NodeType, NodeAttr, EdgeAttr # noqa diff --git a/src/funtracks/data_model/graph_attributes.py b/src/funtracks/data_model/graph_attributes.py deleted file mode 100644 index 3460b5e4..00000000 --- a/src/funtracks/data_model/graph_attributes.py +++ /dev/null @@ -1,87 +0,0 @@ -import warnings -from enum import Enum, EnumMeta - - -class DeprecatedEnumMeta(EnumMeta): - """Metaclass for deprecated enums that issues warnings on member access.""" - - def __getattribute__(cls, name): - """Issue deprecation warning when accessing enum members.""" - # Get the attribute first to avoid blocking access - value = super().__getattribute__(name) - - # Issue warning only for actual enum members (not special attributes) - if ( - not name.startswith("_") - and name not in ("name", "value") - and isinstance(value, cls) - ): - enum_name = cls.__name__ - - # Customize message based on enum type - if enum_name == "NodeType": - message = ( - f"NodeType.{name} is deprecated and will be removed in " - "funtracks v2.0. This is a visualization concern and " - "should be moved to motile_tracker." - ) - elif enum_name == "NodeAttr": - message = ( - f"NodeAttr.{name} is deprecated and will be removed in " - "funtracks v2.0. Use string keys from tracks.features " - "instead (e.g., tracks.features.position_key)." - ) - elif enum_name == "EdgeAttr": - message = ( - f"EdgeAttr.{name} is deprecated and will be removed in " - "funtracks v2.0. Use string keys directly (e.g., 'iou')." - ) - else: - message = ( - f"{enum_name}.{name} is deprecated and will be removed " - "in funtracks v2.0." - ) - - warnings.warn(message, DeprecationWarning, stacklevel=2) - - return value - - -class NodeAttr(Enum, metaclass=DeprecatedEnumMeta): - """Node attributes that can be added to candidate graph. - - .. deprecated:: 2.0 - NodeAttr enum will be removed in funtracks v2.0. Use string keys from - tracks.features instead (e.g., tracks.features.position_key, "area", etc.). - """ - - POS = "pos" - TIME = "time" - AREA = "area" - TRACK_ID = "track_id" - SEG_ID = "seg_id" - - -class EdgeAttr(Enum, metaclass=DeprecatedEnumMeta): - """Edge attributes that can be added to candidate graph. - - .. deprecated:: 2.0 - EdgeAttr enum will be removed in funtracks v2.0. Use string keys directly - (e.g., "iou"). - """ - - IOU = "iou" - - -class NodeType(Enum, metaclass=DeprecatedEnumMeta): - """Types of nodes in the track graph. Currently used for standardizing - visualization. All nodes are exactly one type. - - .. deprecated:: 2.0 - NodeType will be removed in funtracks v2.0. This is a visualization - concern and should be moved to motile_tracker. - """ - - SPLIT = "SPLIT" - END = "END" - CONTINUE = "CONTINUE" diff --git a/src/funtracks/data_model/solution_tracks.py b/src/funtracks/data_model/solution_tracks.py index 181d968d..3c186af8 100644 --- a/src/funtracks/data_model/solution_tracks.py +++ b/src/funtracks/data_model/solution_tracks.py @@ -1,6 +1,5 @@ from __future__ import annotations -import warnings from typing import TYPE_CHECKING import networkx as nx @@ -11,8 +10,6 @@ from .tracks import Tracks if TYPE_CHECKING: - from pathlib import Path - from funtracks.annotators import TrackAnnotator from .tracks import Node @@ -74,21 +71,6 @@ def __init__( self.track_annotator = self._get_track_annotator() - def _initialize_track_ids(self) -> None: - """Initialize track IDs for all nodes. - - Deprecated: - This method is deprecated and will be removed in funtracks v2.0. - Track IDs are now auto-computed during SolutionTracks initialization. - """ - warnings.warn( - "`_initialize_track_ids` is deprecated and will be removed in funtracks v2.0." - " Track IDs are now auto-computed during SolutionTracks initialization.", - DeprecationWarning, - stacklevel=2, - ) - self.enable_features([self.features.tracklet_key]) # type: ignore - def _get_track_annotator(self) -> TrackAnnotator: """Get the TrackAnnotator instance from the annotator registry. @@ -136,16 +118,6 @@ def max_track_id(self) -> int: def track_id_to_node(self) -> dict[int, list[int]]: return self.track_annotator.tracklet_id_to_nodes - @property - def node_id_to_track_id(self) -> dict[Node, int]: - warnings.warn( - "node_id_to_track_id property will be removed in funtracks v2. " - "Use `get_track_id` instead for better performance.", - DeprecationWarning, - stacklevel=2, - ) - return nx.get_node_attributes(self.graph, self.features.tracklet_key) - def get_next_track_id(self) -> int: """Return the next available track_id and update max_tracklet_id in TrackAnnotator @@ -162,35 +134,6 @@ def get_track_id(self, node) -> int: track_id = self.get_node_attr(node, self.features.tracklet_key, required=True) return track_id - def export_tracks( - self, outfile: Path | str, node_ids: set[int] | None = None - ) -> None: - """Export the tracks from this run to a csv with the following columns: - t,[z],y,x,id,parent_id,track_id - Cells without a parent_id will have an empty string for the parent_id. - Whether or not to include z is inferred from self.ndim - - Args: - outfile (Path): path to output csv file - node_ids (set[int], optional): nodes to be included. If provided, only these - nodes and their ancestors will be included in the output. - - .. deprecated:: 1.0 - `SolutionTracks.export_tracks()` is deprecated and will be removed in v2.0. - Use :func:`funtracks.import_export.export_to_csv` instead. - """ - warnings.warn( - "SolutionTracks.export_tracks() is deprecated and will be removed in v2.0. " - "Use funtracks.import_export.export_to_csv() instead.", - DeprecationWarning, - stacklevel=2, - ) - - # Import here to avoid circular imports - from funtracks.import_export.csv._export import export_to_csv - - export_to_csv(self, outfile, node_ids) - def get_track_neighbors( self, track_id: int, time: int ) -> tuple[Node | None, Node | None]: diff --git a/src/funtracks/data_model/tracks.py b/src/funtracks/data_model/tracks.py index 9fc8f9dd..1ae44b4e 100644 --- a/src/funtracks/data_model/tracks.py +++ b/src/funtracks/data_model/tracks.py @@ -1,7 +1,6 @@ from __future__ import annotations import logging -import warnings from collections.abc import Iterable, Sequence from typing import ( TYPE_CHECKING, @@ -13,13 +12,10 @@ import networkx as nx import numpy as np from psygnal import Signal -from skimage import measure from funtracks.features import Feature, FeatureDict, Position, Time if TYPE_CHECKING: - from pathlib import Path - from funtracks.actions import BasicAction from funtracks.annotators import AnnotatorRegistry, GraphAnnotator @@ -126,26 +122,6 @@ def __init__( else: self._setup_core_computed_features() - @property - def time_attr(self): - warn( - "Deprecating Tracks.time_attr in favor of tracks.features.time_key." - " Will be removed in funtracks v2.0.", - DeprecationWarning, - stacklevel=2, - ) - return self.features.time_key - - @property - def pos_attr(self): - warn( - "Deprecating Tracks.pos_attr in favor of tracks.features.position_key." - " Will be removed in funtracks v2.0.", - DeprecationWarning, - stacklevel=2, - ) - return self.features.position_key - def _get_feature_set( self, time_attr: str | None, @@ -390,7 +366,6 @@ def set_positions( self, nodes: Iterable[Node], positions: np.ndarray, - incl_time: bool = False, ): """Set the location of nodes in the graph. Optionally include the time frame as the first dimension. Raises an error if any of the nodes @@ -408,10 +383,6 @@ def set_positions( if not isinstance(positions, np.ndarray): positions = np.array(positions) - if incl_time: - times = positions[:, 0].tolist() # we know this is a list of ints - self.set_times(nodes, times) # type: ignore - positions = positions[:, 1:] if isinstance(self.features.position_key, list): for idx, key in enumerate(self.features.position_key): @@ -419,12 +390,8 @@ def set_positions( else: self._set_nodes_attr(nodes, self.features.position_key, positions.tolist()) - def set_position( - self, node: Node, position: list | np.ndarray, incl_time: bool = False - ): - self.set_positions( - [node], np.expand_dims(np.array(position), axis=0), incl_time=incl_time - ) + def set_position(self, node: Node, position: list | np.ndarray): + self.set_positions([node], np.expand_dims(np.array(position), axis=0)) def get_times(self, nodes: Iterable[Node]) -> Sequence[int]: return self.get_nodes_attr(nodes, self.features.time_key, required=True) @@ -441,109 +408,6 @@ def get_time(self, node: Node) -> int: """ return int(self.get_times([node])[0]) - def set_times(self, nodes: Iterable[Node], times: Iterable[int]): - times = [int(t) for t in times] - self._set_nodes_attr(nodes, self.features.time_key, times) - - def set_time(self, node: Any, time: int): - """Set the time frame of a given node. Raises an error if the node - is not in the graph. - - Args: - node (Any): The node id to set the time frame for - time (int): The time to set - - """ - self.set_times([node], [int(time)]) - - def get_areas(self, nodes: Iterable[Node]) -> Sequence[int | None]: - """Get the area/volume of a given node. Raises a KeyError if the node - is not in the graph. Returns None if the given node does not have an Area - attribute. - - .. deprecated:: 1.0 - `get_areas` will be removed in funtracks v2.0. - Use `get_nodes_attr(nodes, "area")` instead. - - Args: - node (Node): The node id to get the area/volume for - - Returns: - int: The area/volume of the node - """ - warnings.warn( - "`get_areas` is deprecated and will be removed in funtracks v2.0. " - "Use `get_nodes_attr(nodes, 'area')` instead.", - DeprecationWarning, - stacklevel=2, - ) - return self.get_nodes_attr(nodes, "area") - - def get_area(self, node: Node) -> int | None: - """Get the area/volume of a given node. Raises a KeyError if the node - is not in the graph. Returns None if the given node does not have an Area - attribute. - - .. deprecated:: 1.0 - `get_area` will be removed in funtracks v2.0. - Use `get_node_attr(node, "area")` instead. - - Args: - node (Node): The node id to get the area/volume for - - Returns: - int: The area/volume of the node - """ - warnings.warn( - "`get_area` is deprecated and will be removed in funtracks v2.0. " - "Use `get_node_attr(node, 'area')` instead.", - DeprecationWarning, - stacklevel=2, - ) - return self.get_areas([node])[0] - - def get_ious(self, edges: Iterable[Edge]): - """Get the IoU values for the given edges. - - .. deprecated:: 1.0 - `get_ious` will be removed in funtracks v2.0. - Use `get_edges_attr(edges, "iou")` instead. - - Args: - edges: An iterable of edges to get IoU values for. - - Returns: - The IoU values for the edges. - """ - warnings.warn( - "`get_ious` is deprecated and will be removed in funtracks v2.0. " - "Use `get_edges_attr(edges, 'iou')` instead.", - DeprecationWarning, - stacklevel=2, - ) - return self.get_edges_attr(edges, "iou") - - def get_iou(self, edge: Edge): - """Get the IoU value for the given edge. - - .. deprecated:: 1.0 - `get_iou` will be removed in funtracks v2.0. - Use `get_edge_attr(edge, "iou")` instead. - - Args: - edge: An edge to get the IoU value for. - - Returns: - The IoU value for the edge. - """ - warnings.warn( - "`get_iou` is deprecated and will be removed in funtracks v2.0. " - "Use `get_edge_attr(edge, 'iou')` instead.", - DeprecationWarning, - stacklevel=2, - ) - return self.get_edge_attr(edge, "iou") - def get_pixels(self, node: Node) -> tuple[np.ndarray, ...] | None: """Get the pixels corresponding to each node in the nodes list. @@ -577,37 +441,6 @@ def set_pixels(self, pixels: tuple[np.ndarray, ...], value: int) -> None: raise ValueError("Cannot set pixels when segmentation is None") self.segmentation[pixels] = value - def _set_node_attributes(self, node: Node, attributes: dict[str, Any]) -> None: - """Set the attributes for the given node - - Args: - node (Node): The node to set the attributes for - attributes (dict[str, Any]): A mapping from attribute name to value - """ - if node in self.graph: - for key, value in attributes.items(): - self.graph.nodes[node][key] = value - else: - logger.info("Node %d not found in the graph.", node) - - def _set_edge_attributes(self, edge: Edge, attributes: dict[str, Any]) -> None: - """Set the edge attributes for the given edges. Attributes should already exist - (although adding will work in current implementation, they cannot currently be - removed) - - Args: - edges (list[Edge]): A list of edges to set the attributes for - attributes (Attributes): A dictionary of attribute name -> numpy array, - where the length of the arrays matches the number of edges. - Attributes should already exist: this function will only - update the values. - """ - if self.graph.has_edge(*edge): - for key, value in attributes.items(): - self.graph.edges[edge][key] = value - else: - logger.info("Edge %s not found in the graph.", edge) - def _compute_ndim( self, seg: np.ndarray | None, @@ -648,25 +481,9 @@ def get_node_attr(self, node: Node, attr: str, required: bool = False): else: return self.graph.nodes[node].get(attr, None) - def _get_node_attr(self, node, attr, required=False): - warnings.warn( - "_get_node_attr deprecated in favor of public method get_node_attr", - DeprecationWarning, - stacklevel=2, - ) - return self.get_node_attr(node, attr, required=required) - def get_nodes_attr(self, nodes: Iterable[Node], attr: str, required: bool = False): return [self.get_node_attr(node, attr, required=required) for node in nodes] - def _get_nodes_attr(self, nodes, attr, required=False): - warnings.warn( - "_get_nodes_attr deprecated in favor of public method get_nodes_attr", - DeprecationWarning, - stacklevel=2, - ) - return self.get_nodes_attr(nodes, attr, required=required) - def _set_edge_attr(self, edge: Edge, attr: str, value: Any): self.graph.edges[edge][attr] = value @@ -748,111 +565,3 @@ def disable_features(self, feature_keys: list[str]) -> None: for key in feature_keys: if key in self.features: del self.features[key] - - # ========== Persistence ========== - - def save(self, directory: Path): - """Save the tracks to the given directory. - Currently, saves the graph as a json file in networkx node link data format, - saves the segmentation as a numpy npz file, and saves the time and position - attributes and scale information in an attributes json file. - Args: - directory (Path): The directory to save the tracks in. - """ - warn( - "`Tracks.save` is deprecated and will be removed in 2.0, use " - "`funtracks.import_export.internal_format.save` instead", - DeprecationWarning, - stacklevel=2, - ) - from ..import_export.internal_format import save_tracks - - save_tracks(self, directory) - - @classmethod - def load(cls, directory: Path, seg_required=False, solution=False) -> Tracks: - """Load a Tracks object from the given directory. Looks for files - in the format generated by Tracks.save. - Args: - directory (Path): The directory containing tracks to load - seg_required (bool, optional): If true, raises a FileNotFoundError if the - segmentation file is not present in the directory. Defaults to False. - Returns: - Tracks: A tracks object loaded from the given directory - """ - warn( - "`Tracks.load` is deprecated and will be removed in 2.0, use " - "`funtracks.import_export.internal_format.load` instead", - DeprecationWarning, - stacklevel=2, - ) - from ..import_export.internal_format import load_tracks - - return load_tracks(directory, seg_required=seg_required, solution=solution) - - @classmethod - def delete(cls, directory: Path): - """Delete the tracks in the given directory. Also deletes the directory. - - Args: - directory (Path): Directory containing tracks to be deleted - """ - warn( - "`Tracks.delete` is deprecated and will be removed in 2.0, use " - "`funtracks.import_export.internal_format.delete` instead", - DeprecationWarning, - stacklevel=2, - ) - from ..import_export.internal_format import delete_tracks - - delete_tracks(directory) - - def _compute_node_attrs(self, node: Node, time: int) -> dict[str, Any]: - """Get the segmentation controlled node attributes (area and position) - from the segmentation with label based on the node id in the given time point. - - Args: - node (int): The node id to query the current segmentation for - time (int): The time frame of the current segmentation to query - - Returns: - dict[str, int]: A dictionary containing the attributes that could be - determined from the segmentation. It will be empty if self.segmentation - is None. If self.segmentation exists but node id is not present in time, - area will be 0 and position will be None. If self.segmentation - exists and node id is present in time, area and position will be included. - - Deprecated: - This method is deprecated and will be removed in funtracks v2.0. - Use the annotator system (enable_features) to compute node attributes instead. - """ - warn( - "`_compute_node_attrs` is deprecated and will be removed in funtracks v2.0. " - "Use the annotator system (enable_features) to compute node attributes " - "instead.", - DeprecationWarning, - stacklevel=2, - ) - if self.segmentation is None: - return {} - - attrs: dict[str, Any] = {} - seg = self.segmentation[time] == node - pos_scale = self.scale[1:] if self.scale is not None else None - area = np.sum(seg) - if pos_scale is not None: - area *= np.prod(pos_scale) - # only include the position if the segmentation was actually there - pos = ( - measure.centroid(seg, spacing=pos_scale) # type: ignore - if area > 0 - else np.array( - [ - None, - ] - * (self.ndim - 1) - ) - ) - attrs["area"] = area - attrs["pos"] = pos - return attrs diff --git a/src/funtracks/data_model/tracks_controller.py b/src/funtracks/data_model/tracks_controller.py deleted file mode 100644 index e8fe95e7..00000000 --- a/src/funtracks/data_model/tracks_controller.py +++ /dev/null @@ -1,395 +0,0 @@ -from __future__ import annotations - -import warnings -from typing import TYPE_CHECKING -from warnings import warn - -from ..actions import ( - Action, - ActionGroup, - UpdateNodeAttrs, -) -from ..actions.action_history import ActionHistory -from ..user_actions import ( - UserAddEdge, - UserAddNode, - UserDeleteEdge, - UserDeleteNode, - UserUpdateSegmentation, -) -from .solution_tracks import SolutionTracks -from .tracks import Attrs, Edge, Node, SegMask - -if TYPE_CHECKING: - from collections.abc import Iterable - - -class TracksController: - """A set of high level functions to change the data model. - All changes to the data should go through this API. - """ - - def __init__(self, tracks: SolutionTracks): - warnings.warn( - "TracksController deprecated in favor of directly calling UserActions and" - "will be removed in funtracks v2. You will need to keep the action history " - "in your application and emit the tracks refresh.", - DeprecationWarning, - stacklevel=2, - ) - self.tracks = tracks - self.action_history = ActionHistory() - self.node_id_counter = 1 - - def add_nodes( - self, - attributes: Attrs, - pixels: list[SegMask] | None = None, - force: bool = False, - ) -> None: - """Calls the _add_nodes function to add nodes. Calls the refresh signal when - finished. - - Args: - attributes (Attrs): dictionary containing at least time and position - attributes - pixels (list[SegMask] | None, optional): The pixels associated with each - node, if a segmentation is present. Defaults to None. - force (bool): Whether to force the operation by removing conflicting edges. - Defaults to False. - - """ - result = self._add_nodes(attributes, pixels, force) - if result is not None: - action, nodes = result - self.action_history.add_new_action(action) - self.tracks.refresh.emit(nodes[0] if nodes else None) - - def _add_nodes( - self, - attributes: Attrs, - pixels: list[SegMask] | None = None, - force: bool = False, - ) -> tuple[Action, list[Node]] | None: - """Add nodes to the graph. Includes all attributes and the segmentation. - Will return the actions needed to add the nodes, and the node ids generated for - the new nodes. - If there is a segmentation, the attributes must include: - - time - - node_id - - track_id - If there is not a segmentation, the attributes must include: - - time - - pos - - track_id - - Logic of the function: - - remove edges (when we add a node in a track between two nodes - connected by a skip edge) - - add the nodes - - add edges (to connect each node to its immediate - predecessor and successor with the same track_id, if any) - - Args: - attributes (Attrs): dictionary containing at least time and track id, - and either node_id (if pixels are provided) or position (if not) - pixels (list[SegMask] | None): A list of pixels associated with the node, - or None if there is no segmentation. These pixels will be updated - in the tracks.segmentation, set to the new node id. - force (bool): Whether to force the operation by removing conflicting edges. - Defaults to False. - """ - times = attributes[self.tracks.features.time_key] - nodes: list[Node] - if pixels is not None: - nodes = attributes["node_id"] - else: - nodes = self._get_new_node_ids(len(times)) - actions: list[ActionGroup | Action] = [] - nodes_added = [] - for i in range(len(nodes)): - actions.append( - UserAddNode( - self.tracks, - node=nodes[i], - attributes={key: val[i] for key, val in attributes.items()}, - pixels=pixels[i] if pixels is not None else None, - force=force, - ) - ) - nodes_added.append(nodes[i]) - - return ActionGroup(self.tracks, actions), nodes_added - - def delete_nodes(self, nodes: Iterable[Node]) -> None: - """Calls the _delete_nodes function and then emits the refresh signal - - Args: - nodes (Iterable[Node]): array of node_ids to be deleted - """ - - action = self._delete_nodes(nodes) - self.action_history.add_new_action(action) - self.tracks.refresh.emit() - - def _delete_nodes( - self, nodes: Iterable[Node], pixels: Iterable[SegMask] | None = None - ) -> Action: - """Delete the nodes provided by the array from the graph but maintain successor - track_ids. Reconnect to the nearest predecessor and/or nearest successor - on the same track, if any. - - Function logic: - - delete all edges incident to the nodes - - delete the nodes - - add edges to preds and succs of nodes if they have the same track id - - update track ids if we removed a division by deleting the dge - - Args: - nodes (Iterable[Node]): array of node_ids to be deleted - pixels (Iterable[SegMask] | None): pixels of the nodes to be deleted, if - known already. Will be computed if not provided. - """ - actions: list[ActionGroup | Action] = [] - pixels = list(pixels) if pixels is not None else None - for i, node in enumerate(nodes): - actions.append( - UserDeleteNode( - self.tracks, - node, - pixels=pixels[i] if pixels is not None else None, - ) - ) - return ActionGroup(self.tracks, actions) - - def add_edges(self, edges: Iterable[Edge], force: bool = False) -> None: - """Add edges to the graph. Also update the track ids and - corresponding segmentations if applicable - - Args: - edges (Iterable[Edge]): An iterable of edges, each with source and target - node ids - force (bool): Whether to force this operation by removing conflicting edges. - Defaults to False. - """ - for edge in edges: - is_valid = self.is_valid(edge) - if not is_valid: - # warning was printed with details in is_valid call - return - - action: Action - action = self._add_edges(edges, force) - self.action_history.add_new_action(action) - self.tracks.refresh.emit() - - def update_node_attrs(self, nodes: Iterable[Node], attributes: Attrs): - """Update the user provided node attributes (not the managed attributes). - Also adds the action to the history and emits the refresh signal. - - Args: - nodes (Iterable[Node]): The nodes to update the attributes for - attributes (Attrs): A mapping from user-provided attributes to values for - each node. - """ - action = self._update_node_attrs(nodes, attributes) - self.action_history.add_new_action(action) - self.tracks.refresh.emit() - - def _update_node_attrs(self, nodes: Iterable[Node], attributes: Attrs) -> Action: - """Update the user provided node attributes (not the managed attributes). - - Args: - nodes (Iterable[Node]): The nodes to update the attributes for - attributes (Attrs): A mapping from user-provided attributes to values for - each node. - - Returns: An Action object that performed the update - """ - actions: list[ActionGroup | Action] = [] - for i, node in enumerate(nodes): - actions.append( - UpdateNodeAttrs( - self.tracks, node, {key: val[i] for key, val in attributes.items()} - ) - ) - return ActionGroup(self.tracks, actions) - - def _add_edges(self, edges: Iterable[Edge], force: bool = False) -> ActionGroup: - """Add edges and attributes to the graph. Also update the track ids of the - target node tracks and potentially sibling tracks. - - Args: - edges (Iterable[edge]): An iterable of edges, each with source and target - node ids - force (bool): Whether to force this action by removing conflicting edges. - - Returns: - An Action containing all edits performed in this call - """ - actions: list[ActionGroup | Action] = [] - for edge in edges: - actions.append(UserAddEdge(self.tracks, edge, force)) - return ActionGroup(self.tracks, actions) - - def is_valid(self, edge: Edge) -> bool: - """Check if this edge is valid. - Criteria: - - not horizontal - - not existing yet - - no triple divisions - - new edge should be the shortest possible connection between two nodes, given - their track_ids (no skipping/bypassing any nodes of the same track_id). - Check if there are any nodes of the same source or target track_id between - source and target - - Args: - edge (Edge): edge to be validated - - Returns: - True if the edge is valid, false if invalid""" - - # make sure that the node2 is downstream of node1 - time1 = self.tracks.get_time(edge[0]) - time2 = self.tracks.get_time(edge[1]) - - if time1 > time2: - edge = (edge[1], edge[0]) - time1, time2 = time2, time1 - # do all checks - # reject if edge already exists - if self.tracks.graph.has_edge(edge[0], edge[1]): - warn("Edge is rejected because it exists already.", stacklevel=2) - return False - - # reject if edge is horizontal - elif self.tracks.get_time(edge[0]) == self.tracks.get_time(edge[1]): - warn("Edge is rejected because it is horizontal.", stacklevel=2) - return False - - elif self.tracks.graph.out_degree(edge[0]) > 1: - warn( - "Edge is rejected because triple divisions are currently not allowed.", - stacklevel=2, - ) - return False - - elif time2 - time1 > 1: - track_id2 = self.tracks.get_track_id(edge[1]) - # check whether there are already any nodes with the same track id between - # source and target (shortest path between equal track_ids rule) - for t in range(time1 + 1, time2): - nodes = [ - n - for n in self.tracks.nodes() - if self.tracks.get_time(n) == t - and self.tracks.get_track_id(n) == track_id2 - ] - if len(nodes) > 0: - warn("Please connect to the closest node", stacklevel=2) - return False - - # all checks passed! - return True - - def delete_edges(self, edges: Iterable[Edge]): - """Delete edges from the graph. - - Args: - edges (Iterable[Edge]): The Nx2 array of edges to be deleted - """ - - for edge in edges: - # First check if the to be deleted edges exist - if not self.tracks.graph.has_edge(edge[0], edge[1]): - warn("Cannot delete non-existing edge!", stacklevel=2) - return - action = self._delete_edges(edges) - self.action_history.add_new_action(action) - self.tracks.refresh.emit() - - def _delete_edges(self, edges: Iterable[Edge]) -> ActionGroup: - actions: list[ActionGroup | Action] = [] - for edge in edges: - actions.append(UserDeleteEdge(self.tracks, edge)) - return ActionGroup(self.tracks, actions) - - def update_segmentations( - self, - new_value: int, - updated_pixels: list[tuple[SegMask, int]], - current_timepoint: int, - current_track_id: int, - force: bool = False, - ): - """Handle a change in the segmentation mask, checking for node addition, - deletion, and attribute updates. - - NOTE: we have introduced a minor breaking change to this API that finn will need - to adapt to - it used to parse the pixel change into different action lists, - but that is now done in the UserUpdateSegmentation action - - Args: - new_value (int)): the label that the user drew with - updated_pixels (list[tuple[SegMask, int]]): a list of pixels changed - and the value that was there before the user drew - current_timepoint (int): the current time point in the viewer, used to set - the selected node. - current_track_id (int): the track_id to use when adding a new node, usually - the currently selected track id in the viewer - force (bool): Whether to force the operation by removing conflicting edges. - Defaults to False. - """ - - action = UserUpdateSegmentation( - self.tracks, new_value, updated_pixels, current_track_id, force - ) - self.action_history.add_new_action(action) - nodes_added = action.nodes_added - times = self.tracks.get_times(nodes_added) - if current_timepoint in times: - node_to_select = nodes_added[times.index(current_timepoint)] - else: - node_to_select = None - self.tracks.refresh.emit(node_to_select) - - def undo(self) -> bool: - """Obtain the action to undo from the history, and invert. - Returns: - bool: True if the action was undone, False if there were no more actions - """ - if self.action_history.undo(): - self.tracks.refresh.emit() - return True - else: - return False - - def redo(self) -> bool: - """Obtain the action to redo from the history - Returns: - bool: True if the action was re-done, False if there were no more actions - """ - if self.action_history.redo(): - self.tracks.refresh.emit() - return True - else: - return False - - def _get_new_node_ids(self, n: int) -> list[Node]: - """Get a list of new node ids for creating new nodes. - They will be unique from all existing nodes, but have no other guarantees. - - Args: - n (int): The number of new node ids to return - - Returns: - list[Node]: A list of new node ids. - """ - ids = [self.node_id_counter + i for i in range(n)] - self.node_id_counter += n - for idx, _id in enumerate(ids): - while self.tracks.graph.has_node(_id): - _id = self.node_id_counter - self.node_id_counter += 1 - ids[idx] = _id - return ids diff --git a/src/funtracks/import_export/__init__.py b/src/funtracks/import_export/__init__.py index 4cb8101c..28313e6a 100644 --- a/src/funtracks/import_export/__init__.py +++ b/src/funtracks/import_export/__init__.py @@ -1,9 +1,9 @@ from ._tracks_builder import TracksBuilder +from ._v1_format import load_v1_tracks from .csv._export import export_to_csv from .csv._import import CSVTracksBuilder, tracks_from_df from .geff._export import export_to_geff from .geff._import import GeffTracksBuilder, import_from_geff -from .internal_format import load_tracks, save_tracks from .magic_imread import magic_imread __all__ = [ @@ -14,7 +14,6 @@ "tracks_from_df", "export_to_csv", "export_to_geff", - "save_tracks", - "load_tracks", + "load_v1_tracks", "magic_imread", ] diff --git a/src/funtracks/import_export/_tracks_builder.py b/src/funtracks/import_export/_tracks_builder.py index 7f63946f..92506a44 100644 --- a/src/funtracks/import_export/_tracks_builder.py +++ b/src/funtracks/import_export/_tracks_builder.py @@ -16,7 +16,6 @@ import numpy as np from geff._typing import InMemoryGeff -from funtracks.data_model.graph_attributes import NodeAttr from funtracks.data_model.solution_tracks import SolutionTracks from funtracks.features import Feature from funtracks.import_export._import_segmentation import ( @@ -464,7 +463,8 @@ def handle_segmentation( return seg_array.compute(), scale # Relabel segmentation: seg_id -> node_id - time_values = node_props[NodeAttr.TIME.value]["values"] + time_attr = "time" + time_values = node_props[time_attr]["values"] new_segmentation = relabel_segmentation( seg_array, graph, node_ids, seg_ids, time_values ) diff --git a/src/funtracks/import_export/internal_format.py b/src/funtracks/import_export/_v1_format.py similarity index 97% rename from src/funtracks/import_export/internal_format.py rename to src/funtracks/import_export/_v1_format.py index a7ffcb63..f6372626 100644 --- a/src/funtracks/import_export/internal_format.py +++ b/src/funtracks/import_export/_v1_format.py @@ -18,8 +18,9 @@ ATTRS_FILE = "attrs.json" -def save_tracks(tracks: Tracks, directory: Path) -> None: - """Save the tracks to the given directory. +def _save_v1_tracks(tracks: Tracks, directory: Path) -> None: + """Only used for testing backward compatibility! + Currently, saves the graph as a json file in networkx node link data format, saves the segmentation as a numpy npz file, and saves the time and position attributes and scale information in an attributes json file. @@ -99,12 +100,14 @@ def _save_attrs(tracks: Tracks, directory: Path) -> None: json.dump(attrs_dict, f) -def load_tracks( +def load_v1_tracks( directory: Path, seg_required: bool = False, solution: bool = False ) -> Tracks | SolutionTracks: """Load a Tracks object from the given directory. Looks for files in the format generated by Tracks.save. + TODO: retain loading capabilities for legacy tracks + Args: directory (Path): The directory containing tracks to load seg_required (bool, optional): If true, raises a FileNotFoundError if the diff --git a/tests/data_model/test_solution_tracks.py b/tests/data_model/test_solution_tracks.py index 30c5f409..3994d329 100644 --- a/tests/data_model/test_solution_tracks.py +++ b/tests/data_model/test_solution_tracks.py @@ -1,6 +1,5 @@ import networkx as nx import numpy as np -import pytest from funtracks.actions import AddNode from funtracks.data_model import SolutionTracks, Tracks @@ -28,15 +27,6 @@ def test_next_track_id(graph_2d_with_computed_features): assert tracks.get_next_track_id() == 11 -def test_node_id_to_track_id(graph_2d_with_computed_features): - tracks = SolutionTracks(graph_2d_with_computed_features, ndim=3, **track_attrs) - with pytest.warns( - DeprecationWarning, - match="node_id_to_track_id property will be removed in funtracks v2. ", - ): - tracks.node_id_to_track_id # noqa B018 - - def test_from_tracks_cls(graph_2d_with_computed_features): tracks = Tracks( graph_2d_with_computed_features, @@ -83,55 +73,6 @@ def test_next_track_id_empty(): assert tracks.get_next_track_id() == 1 -def test_export_to_csv( - graph_2d_with_computed_features, graph_3d_with_computed_features, tmp_path -): - # Test backward-compatible default format (use_display_names=False) - tracks = SolutionTracks(graph_2d_with_computed_features, **track_attrs, ndim=3) - temp_file = tmp_path / "test_export_2d.csv" - tracks.export_tracks(temp_file) - with open(temp_file) as f: - lines = f.readlines() - - assert len(lines) == tracks.graph.number_of_nodes() + 1 # add header - - # Backward compatible format: t, y, x, id, parent_id, track_id - header = ["t", "y", "x", "id", "parent_id", "track_id"] - assert lines[0].strip().split(",") == header - - tracks = SolutionTracks(graph_3d_with_computed_features, **track_attrs, ndim=4) - temp_file = tmp_path / "test_export_3d.csv" - tracks.export_tracks(temp_file) - with open(temp_file) as f: - lines = f.readlines() - - assert len(lines) == tracks.graph.number_of_nodes() + 1 # add header - - # Backward compatible format: t, z, y, x, id, parent_id, track_id - header = ["t", "z", "y", "x", "id", "parent_id", "track_id"] - assert lines[0].strip().split(",") == header - - # Test exporting a selection of nodes. We have 6 nodes in total and we ask to save - # node 4 and 6. Because node 1 and 3 are ancestors of node 4, we expect them to be - # included as well to maintain a valid graph without missing parents. - tracks.export_tracks(temp_file, node_ids=[4, 6]) - with open(temp_file) as f: - lines = f.readlines() - - assert len(lines) == 5 # (4 nodes + 1 header) - - # In backward-compatible format, node ID is 5th column (index 4): t, z, y, x, id, ... - node_ids_in_csv = [int(line.split(",")[4]) for line in lines[1:]] - expected_node_ids = [1, 3, 4, 6] - assert sorted(node_ids_in_csv) == sorted(expected_node_ids), ( - f"Unexpected nodes in CSV: {node_ids_in_csv}" - ) - - # Backward compatible format - header = ["t", "z", "y", "x", "id", "parent_id", "track_id"] - assert lines[0].strip().split(",") == header - - def test_export_to_csv_with_display_names( graph_2d_with_computed_features, graph_3d_with_computed_features, tmp_path ): diff --git a/tests/data_model/test_tracks.py b/tests/data_model/test_tracks.py index 8476e203..3ba2e5a7 100644 --- a/tests/data_model/test_tracks.py +++ b/tests/data_model/test_tracks.py @@ -1,8 +1,6 @@ import networkx as nx import numpy as np import pytest -from networkx.utils import graphs_equal -from numpy.testing import assert_array_almost_equal from funtracks.data_model import Tracks @@ -45,7 +43,7 @@ def test_create_tracks(graph_3d_with_computed_features: nx.DiGraph, segmentation assert tracks.get_positions([1]).tolist() == [[50, 50, 50]] assert tracks.get_time(1) == 0 assert tracks.get_positions([1], incl_time=True).tolist() == [[0, 50, 50, 50]] - tracks.set_time(1, 1) + tracks._set_node_attr(1, tracks.features.time_key, 1) assert tracks.get_positions([1], incl_time=True).tolist() == [[1, 50, 50, 50]] tracks_wrong_attr = Tracks( @@ -82,9 +80,6 @@ def test_create_tracks(graph_3d_with_computed_features: nx.DiGraph, segmentation tracks.set_position(1, [55, 56, 57]) assert tracks.get_position(1) == [55, 56, 57] - tracks.set_position(1, [1, 50, 50, 50], incl_time=True) - assert tracks.get_time(1) == 1 - def test_pixels_and_seg_id(graph_3d_with_computed_features, segmentation_3d): # create track with graph and seg @@ -98,28 +93,6 @@ def test_pixels_and_seg_id(graph_3d_with_computed_features, segmentation_3d): tracks.set_pixels(pix, new_seg_id) -def test_save_load_delete(tmp_path, graph_2d_with_computed_features, segmentation_2d): - tracks_dir = tmp_path / "tracks" - tracks = Tracks(graph_2d_with_computed_features, segmentation_2d, **track_attrs) - with pytest.warns( - DeprecationWarning, - match="`Tracks.save` is deprecated and will be removed in 2.0", - ): - tracks.save(tracks_dir) - with pytest.warns( - DeprecationWarning, - match="`Tracks.load` is deprecated and will be removed in 2.0", - ): - loaded = Tracks.load(tracks_dir) - assert graphs_equal(loaded.graph, tracks.graph) - assert_array_almost_equal(loaded.segmentation, tracks.segmentation) - with pytest.warns( - DeprecationWarning, - match="`Tracks.delete` is deprecated and will be removed in 2.0", - ): - Tracks.delete(tracks_dir) - - def test_nodes_edges(graph_2d_with_computed_features): tracks = Tracks(graph_2d_with_computed_features, ndim=3, **track_attrs) assert set(tracks.nodes()) == {1, 2, 3, 4, 5, 6} @@ -148,38 +121,12 @@ def test_predecessors_successors(graph_2d_with_computed_features): assert tracks.successors(2) == [] -def test_area_methods(graph_2d_with_computed_features): - tracks = Tracks(graph_2d_with_computed_features, ndim=3, **track_attrs) - assert tracks.get_area(1) == 1245 - assert tracks.get_areas([1, 2]) == [1245, 305] - - -def test_iou_methods(graph_2d_with_computed_features): - tracks = Tracks(graph_2d_with_computed_features, ndim=3, **track_attrs) - assert tracks.get_iou((1, 2)) == 0.0 - assert tracks.get_ious([(1, 2)]) == [0.0] - assert tracks.get_ious([(1, 2), (1, 3)]) == [0.0, 0.395] - - def test_get_set_node_attr(graph_2d_with_computed_features): tracks = Tracks(graph_2d_with_computed_features, ndim=3, **track_attrs) tracks._set_node_attr(1, "a", 42) - # test deprecated functions - with pytest.warns( - DeprecationWarning, - match="_get_node_attr deprecated in favor of public method get_node_attr", - ): - assert tracks._get_node_attr(1, "a") == 42 tracks._set_nodes_attr([1, 2], "b", [7, 8]) - with pytest.warns( - DeprecationWarning, - match="_get_nodes_attr deprecated in favor of public method get_nodes_attr", - ): - assert tracks._get_nodes_attr([1, 2], "b") == [7, 8] - - # test new functions assert tracks.get_node_attr(1, "a", required=True) == 42 assert tracks.get_nodes_attr([1, 2], "b", required=True) == [7, 8] assert tracks.get_nodes_attr([1, 2], "b", required=False) == [7, 8] @@ -242,30 +189,6 @@ def test_set_positions_list(graph_2d_list): ) -def test_set_node_attributes(graph_2d_with_computed_features, caplog): - tracks = Tracks(graph_2d_with_computed_features, ndim=3, **track_attrs) - attrs = {"attr_1": 1, "attr_2": ["a", "b", "c", "d", "e", "f"]} - tracks._set_node_attributes(1, attrs) - assert tracks.get_node_attr(1, "attr_1") == 1 - assert tracks.get_node_attr(1, "attr_2") == ["a", "b", "c", "d", "e", "f"] - with caplog.at_level("INFO"): - tracks._set_node_attributes(7, attrs) - assert any("Node 7 not found in the graph." in message for message in caplog.messages) - - -def test_set_edge_attributes(graph_2d_with_computed_features, caplog): - tracks = Tracks(graph_2d_with_computed_features, ndim=3, **track_attrs) - attrs = {"attr_1": 1, "attr_2": ["a", "b", "c", "d"]} - tracks._set_edge_attributes((1, 2), attrs) - assert tracks.get_edge_attr((1, 2), "attr_1") == 1 - assert tracks.get_edge_attr((1, 2), "attr_2") == ["a", "b", "c", "d"] - with caplog.at_level("INFO"): - tracks._set_edge_attributes((4, 6), attrs) - assert any( - "Edge (4, 6) not found in the graph." in message for message in caplog.messages - ) - - def test_get_pixels_and_set_pixels(graph_2d_with_computed_features, segmentation_2d): tracks = Tracks( graph_2d_with_computed_features, segmentation_2d, ndim=3, **track_attrs diff --git a/tests/data_model/test_tracks_controller.py b/tests/data_model/test_tracks_controller.py deleted file mode 100644 index 8047c0e5..00000000 --- a/tests/data_model/test_tracks_controller.py +++ /dev/null @@ -1,329 +0,0 @@ -import numpy as np - -from funtracks.data_model.solution_tracks import SolutionTracks -from funtracks.data_model.tracks_controller import TracksController - - -def test__add_nodes_no_seg(graph_2d_with_computed_features): - # add without segmentation - tracks = SolutionTracks( - graph_2d_with_computed_features, - ndim=3, - time_attr="t", - tracklet_attr="track_id", - ) - controller = TracksController(tracks) - - num_edges = tracks.graph.number_of_edges() - - # start a new track with multiple nodes - attrs = { - "t": [0, 1], - "pos": np.array([[1, 3], [1, 3]]), - "track_id": [6, 6], - } - - action, node_ids = controller._add_nodes(attrs) - - node = node_ids[0] - assert tracks.graph.has_node(node) - assert tracks.get_position(node) == [1, 3] - assert tracks.get_track_id(node) == 6 - - assert tracks.graph.number_of_edges() == num_edges + 1 # one edge added - - # add nodes to end of existing track - attrs = { - "t": [2, 3], - "pos": np.array([[1, 3], [1, 3]]), - "track_id": [2, 2], - } - - action, node_ids = controller._add_nodes(attrs) - - node1 = node_ids[0] - node2 = node_ids[1] - assert tracks.get_position(node1) == [1, 3] - assert tracks.get_track_id(node1) == 2 - assert tracks.graph.has_edge(2, node1) - assert tracks.graph.has_edge(node1, node2) - - # add node to middle of existing track - attrs = { - "t": [3], - "pos": np.array([[1, 3]]), - "track_id": [3], - } - - action, node_ids = controller._add_nodes(attrs) - - node = node_ids[0] - assert tracks.get_position(node) == [1, 3] - assert tracks.get_track_id(node) == 3 - - assert tracks.graph.has_edge(4, node) - assert tracks.graph.has_edge(node, 5) - assert not tracks.graph.has_edge(4, 5) - - -def test__add_nodes_with_seg(graph_2d_with_computed_features, segmentation_2d): - # add with segmentation - tracks = SolutionTracks( - graph_2d_with_computed_features, - segmentation=segmentation_2d, - time_attr="t", - tracklet_attr="track_id", - ) - controller = TracksController(tracks) - - num_edges = tracks.graph.number_of_edges() - - new_seg = segmentation_2d.copy() - time = 0 - track_id = 6 - node1 = 7 - node2 = 8 - new_seg[time : time + 1, 90:100, 0:4] = node1 - new_seg[time + 1 : time + 2, 90:100, 0:4] = node2 - expected_center = [94.5, 1.5] - # start a new track - attrs = { - "t": [time, time + 1], - "track_id": [track_id, track_id], - "node_id": [node1, node2], - } - - loc_pix = np.where(new_seg[time] == node1) - time_pix = np.ones_like(loc_pix[0]) * time - time_pix2 = np.ones_like(loc_pix[0]) * (time + 1) - pixels = [ - (time_pix, *loc_pix), - (time_pix2, *loc_pix), - ] # TODO: get time from pixels? - - action, node_ids = controller._add_nodes(attrs, pixels=pixels) - - node1, node2 = node_ids - assert tracks.get_time(node1) == 0 - assert tracks.get_position(node1) == expected_center - assert tracks.get_track_id(node1) == 6 - assert tracks.get_time(node2) == 1 - assert tracks.get_position(node2) == expected_center - assert tracks.get_track_id(node2) == 6 - assert np.sum(tracks.segmentation != new_seg) == 0 - - assert tracks.graph.number_of_edges() == num_edges + 1 # one edge added - - # add nodes to end of existing track - time = 2 - track_id = 2 - node1 = 9 - node2 = 10 - new_seg[time : time + 1, 0:10, 0:4] = node1 - new_seg[time + 1 : time + 2, 0:10, 0:4] = node2 - expected_center = [4.5, 1.5] - # start a new track - attrs = { - "t": [time, time + 1], - "track_id": [track_id, track_id], - "node_id": [node1, node2], - } - - loc_pix = np.where(new_seg[time] == node1) - time_pix = np.ones_like(loc_pix[0]) * time - time_pix2 = np.ones_like(loc_pix[0]) * (time + 1) - pixels = [(time_pix, *loc_pix), (time_pix2, *loc_pix)] - - _, node_ids = controller._add_nodes(attrs, pixels) - - node = node_ids[0] - assert tracks.get_position(node) == expected_center - assert tracks.get_track_id(node) == 2 - assert np.sum(tracks.segmentation != new_seg) == 0 - - assert tracks.graph.has_edge(2, node) - assert tracks.graph.has_edge(node, node_ids[1]) - - # add node to middle of existing track - time = 3 - track_id = 3 - node1 = 11 - new_seg[time, 0:10, 0:4] = node1 - expected_center = [4.5, 1.5] - attrs = { - "t": [time], - "track_id": [track_id], - "node_id": [node1], - } - - loc_pix = np.where(new_seg[time] == node1) - time_pix = np.ones_like(loc_pix[0]) * time - pixels = [(time_pix, *loc_pix)] - - action, node_ids = controller._add_nodes(attrs, pixels=pixels) - - node = node_ids[0] - assert tracks.get_position(node) == expected_center - assert tracks.get_track_id(node) == 3 - assert np.sum(tracks.segmentation != new_seg) == 0 - - assert tracks.graph.has_edge(4, node) - assert tracks.graph.has_edge(node, 5) - assert not tracks.graph.has_edge(4, 5) - - -def test__delete_nodes_no_seg(graph_2d_with_computed_features): - tracks = SolutionTracks( - graph_2d_with_computed_features, - ndim=3, - time_attr="t", - tracklet_attr="track_id", - ) - controller = TracksController(tracks) - num_edges = tracks.graph.number_of_edges() - - # delete unconnected node - node = 6 - action = controller._delete_nodes([node]) - assert not tracks.graph.has_node(node) - assert tracks.graph.number_of_edges() == num_edges - action.inverse() - - # delete end node - node = 5 - action = controller._delete_nodes([node]) - assert not tracks.graph.has_node(node) - assert not tracks.graph.has_edge(4, node) - action.inverse() - - # delete continuation node - node = 4 - action = controller._delete_nodes([node]) - assert not tracks.graph.has_node(node) - assert not tracks.graph.has_edge(3, node) - assert not tracks.graph.has_edge(node, 5) - assert tracks.graph.has_edge(3, 5) - assert tracks.get_track_id(5) == 3 - action.inverse() - - # delete div parent - node = 1 - action = controller._delete_nodes([node]) - assert not tracks.graph.has_node(node) - assert not tracks.graph.has_edge(node, 2) - assert not tracks.graph.has_edge(node, 3) - action.inverse() - - # delete div child - node = 3 - action = controller._delete_nodes([node]) - assert not tracks.graph.has_node(node) - assert tracks.get_track_id(2) == 1 # update track id for other child - - -def test__delete_nodes_with_seg(graph_2d_with_computed_features, segmentation_2d): - tracks = SolutionTracks( - graph_2d_with_computed_features, - segmentation=segmentation_2d, - time_attr="t", - tracklet_attr="track_id", - ) - controller = TracksController(tracks) - num_edges = tracks.graph.number_of_edges() - - # delete unconnected node - node = 6 - track_id = 6 - time = 4 - action = controller._delete_nodes([node]) - assert not tracks.graph.has_node(node) - assert track_id not in np.unique(tracks.segmentation[time]) - assert tracks.graph.number_of_edges() == num_edges - action.inverse() - - # delete end node - node = 5 - track_id = 3 - time = 4 - action = controller._delete_nodes([node]) - assert not tracks.graph.has_node(node) - assert track_id not in np.unique(tracks.segmentation[time]) - assert not tracks.graph.has_edge(4, node) - action.inverse() - - # delete continuation node - node = 4 - track_id = 3 - time = 2 - action = controller._delete_nodes([node]) - assert not tracks.graph.has_node(node) - assert track_id not in np.unique(tracks.segmentation[time]) - assert not tracks.graph.has_edge(3, node) - assert not tracks.graph.has_edge(node, 5) - assert tracks.graph.has_edge(3, 5) - assert tracks.get_track_id(5) == 3 - action.inverse() - - # delete div parent - node = 1 - track_id = 1 - time = 0 - action = controller._delete_nodes([node]) - assert not tracks.graph.has_node(node) - assert track_id not in np.unique(tracks.segmentation[time]) - assert not tracks.graph.has_edge(node, 2) - assert not tracks.graph.has_edge(node, 3) - action.inverse() - - # delete div child - node = 2 - track_id = 2 - time = 1 - action = controller._delete_nodes([node]) - assert not tracks.graph.has_node(node) - assert track_id not in np.unique(tracks.segmentation[time]) - assert tracks.get_track_id(3) == 1 # update track id for other child - assert tracks.get_track_id(5) == 1 # update track id for other child - - -def test__add_remove_edges_no_seg(graph_2d_with_computed_features): - tracks = SolutionTracks( - graph_2d_with_computed_features, - ndim=3, - time_attr="t", - tracklet_attr="track_id", - ) - controller = TracksController(tracks) - num_edges = tracks.graph.number_of_edges() - - # delete continuation edge - edge = (3, 4) - track_id = 3 - controller._delete_edges([edge]) - assert not tracks.graph.has_edge(*edge) - assert tracks.get_track_id(edge[1]) != track_id # relabeled the rest of the track - assert tracks.graph.number_of_edges() == num_edges - 1 - - # add back in continuation edge - controller._add_edges([edge]) - assert tracks.graph.has_edge(*edge) - assert tracks.get_track_id(edge[1]) == track_id # track id was changed back - assert tracks.graph.number_of_edges() == num_edges - - # delete division edge - edge = (1, 3) - track_id = 3 - controller._delete_edges([edge]) - assert not tracks.graph.has_edge(*edge) - assert tracks.get_track_id(edge[1]) == track_id # dont relabel after removal - assert tracks.get_track_id(2) == 1 # but do relabel the sibling - assert tracks.graph.number_of_edges() == num_edges - 1 - - # add back in division edge - edge = (1, 3) - track_id = 3 - controller._add_edges([edge]) - assert tracks.graph.has_edge(*edge) - assert tracks.get_track_id(edge[1]) == track_id # dont relabel after removal - assert tracks.get_track_id(2) != 1 # give sibling new id again (not necessarily 2) - assert tracks.graph.number_of_edges() == num_edges diff --git a/tests/import_export/test_internal_format.py b/tests/import_export/test_internal_format.py index 9ebba83b..b48e6666 100644 --- a/tests/import_export/test_internal_format.py +++ b/tests/import_export/test_internal_format.py @@ -6,10 +6,10 @@ from numpy.testing import assert_array_almost_equal from funtracks.data_model import Tracks -from funtracks.import_export.internal_format import ( +from funtracks.import_export._v1_format import ( + _save_v1_tracks, delete_tracks, - load_tracks, - save_tracks, + load_v1_tracks, ) @@ -24,9 +24,9 @@ def test_save_load( tmp_path, ): tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=is_solution) - save_tracks(tracks, tmp_path) + _save_v1_tracks(tracks, tmp_path) - loaded = load_tracks(tmp_path, solution=is_solution) + loaded = load_v1_tracks(tmp_path, solution=is_solution) assert loaded.ndim == tracks.ndim # Check feature keys and important properties match (allow tuple vs list diff) assert loaded.features.time_key == tracks.features.time_key @@ -86,7 +86,7 @@ def test_delete( ): tracks_path = tmp_path / "test_tracks" tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=is_solution) - save_tracks(tracks, tracks_path) + _save_v1_tracks(tracks, tracks_path) delete_tracks(tracks_path) with pytest.raises(StopIteration): next(tmp_path.iterdir()) @@ -96,7 +96,7 @@ def test_delete( def test_load_without_features(tmp_path, graph_2d_with_computed_features): tracks = Tracks(graph_2d_with_computed_features, ndim=3) tracks_path = tmp_path / "test_tracks" - save_tracks(tracks, tracks_path) + _save_v1_tracks(tracks, tracks_path) attrs_path = tracks_path / "attrs.json" with open(attrs_path) as f: attrs = json.load(f) @@ -107,6 +107,6 @@ def test_load_without_features(tmp_path, graph_2d_with_computed_features): with open(attrs_path, "w") as f: json.dump(attrs, f) - imported_tracks = load_tracks(tracks_path) + imported_tracks = load_v1_tracks(tracks_path) assert imported_tracks.features.time_key == "time" assert imported_tracks.features.position_key == "pos"