Add support for point-windowed training (e.g. synapse assignment) #1137

@JoeStrout

Problem

Assignment model training differs from typical binary classifier training in two ways:

  1. The model operates on a small 3D patch around specific points (given as point or line annotations).
  2. The model has two output channels (presynaptic and postsynaptic cell mask) instead of just one.

The second point is easy to handle, but the first one requires some new code.

Proposal

Add support for this sort of "point-windowed" training: let the user specify a "points" layer containing point or line annotations to window around, and add a new sampler that efficiently samples small windows around those points within each bounding box.

Data Organization

Layer Group defines all data sources:

  • "img" - EM image data
  • "label0" - Presynaptic cell labels (ground truth)
  • "label1" - Postsynaptic cell labels (ground truth)
  • "points" - Precomputed annotation layer containing point or line annotations marking synapse locations

Collection contains bounding box annotations:

  • Each bbox annotation defines a volume of interest
  • Tagged with "train" or "val" for dataset split
  • Bboxes should be small enough to fit in memory (plus padding)

Training Workflow

For each bbox annotation in the collection, the code will (see the sketch after this list):

  1. Load volume into memory

    • Read the entire bbox volume from all layers (img, label0, label1)
    • Inflate by half the window size on each side (to allow edge samples)
    • Cache this data in memory
  2. Read synapse locations

    • Read all point/line annotations from "points" layer that fall within bbox
    • For line annotations, compute midpoint
    • Filter to only points inside the (non-inflated) bbox
  3. Create point-based sampler

    • Create a dataset that samples centered windows around each point
    • Windows are extracted from the cached in-memory volume
    • Many samples per bbox, all from cached data
  4. Combine datasets

    • One dataset per bbox (as the current build_collection_dataset already does)
    • Join vertically to create final training dataset
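A rough sketch of this per-bbox assembly, with hypothetical helpers (read_bbox_volume, read_points_in_bbox, and make_point_dataset are illustrative stand-ins, not existing zetta_utils APIs; bbox.padded is likewise assumed):

from torch.utils.data import ConcatDataset

def build_point_windowed_dataset(bbox_annotations, layers, chunk_size):
    per_bbox_datasets = []
    half = [s // 2 for s in chunk_size]
    for bbox in bbox_annotations:
        # 1. Inflate by half the window size and cache the volume in memory.
        inflated = bbox.padded(half)  # assumed padding helper
        cached = {
            name: read_bbox_volume(layer, inflated)  # one read per layer
            for name, layer in layers.items()
            if name != "points"
        }
        # 2. Read synapse locations; line annotations reduce to midpoints.
        points = read_points_in_bbox(layers["points"], bbox)
        # 3. One dataset per bbox, sampling windows from the cached data.
        per_bbox_datasets.append(make_point_dataset(cached, points, chunk_size))
    # 4. Join vertically into the final training dataset.
    return ConcatDataset(per_bbox_datasets)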

Benefits

  • Each volume loaded once, many point samples extracted from it
  • Leverages existing collection bbox infrastructure
  • No need to modify db_annotations schema
  • Precomputed points layer treated like any other data layer
  • Efficient: amortizes I/O cost across many samples per volume
  • Natural train/val split via bbox tags

Implementation Plan

1. Extend build_collection_dataset

File: zetta_utils/training/datasets/collection_dataset.py

Changes:

  • Detect whether the layer group contains a "points" layer (or a configurable layer name)
  • For each bbox annotation:
    • Read points from annotation layer within bbox bounds
    • Inflate bbox by chunk_size // 2 for padding
    • Pre-load inflated bbox data into memory (or create cached layer view)
    • Create custom indexer that:
      • Holds list of point locations
      • Maps integer index → centered window around that point
      • Reads from cached data
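For concreteness, the inflation arithmetic for chunk_size [64, 64, 64] at resolution [4, 4, 40] (matching the example spec in section 6); this is plain arithmetic, not a specific API:

chunk_size = (64, 64, 64)   # window size, in voxels
resolution = (4, 4, 40)     # nm per voxel
pad_voxels = tuple(s // 2 for s in chunk_size)                 # (32, 32, 32)
pad_nm = tuple(p * r for p, r in zip(pad_voxels, resolution))  # (128, 128, 1280)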

Questions to resolve:

  • Should cached data be held by the indexer or the layer?
  • Should we support configurable layer names (not hardcoded "points")?

2. Create CachedPointIndexer

File: zetta_utils/training/datasets/sample_indexers/cached_point_indexer.py (new)

Purpose: Indexer that samples windows around points from cached volume data

Interface (sketch; exact base-class and constructor signatures may differ):

@builder.register("CachedPointIndexer")
class CachedPointIndexer(SampleIndexer):
    points: list[Vec3D]  # Synapse locations
    chunk_size: Sequence[int]  # Window size
    resolution: Sequence[float]
    bbox: BBox3D  # Bbox bounds for validation

    def __len__(self) -> int:
        return len(self.points)

    def __call__(self, idx: int) -> VolumetricIndex:
        # Return VolumetricIndex for window centered at points[idx]

Key features:

  • Takes list of point locations (already filtered to bbox)
  • Computes centered window for each point
  • Returns VolumetricIndex like other indexers
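Usage would mirror the other sample indexers (a hypothetical sketch following the interface above):

indexer = CachedPointIndexer(
    points=points,            # already filtered to the bbox
    chunk_size=(64, 64, 64),
    resolution=(4, 4, 40),
    bbox=bbox,
)
num_samples = len(indexer)    # one sample per synapse point
sample_index = indexer(0)     # VolumetricIndex for the first window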

3. Point Extraction Logic

Location: Helper function in collection_dataset.py or separate utility

Purpose: Read and filter points from annotation layer

Functionality:

  • Read precomputed annotations layer at bbox bounds
  • Extract point coordinates (handle both point and line types)
  • For line annotations, compute midpoint: (pointA + pointB) / 2 (or maybe have an extra parameter to control this)
  • Filter to points within bbox (accounting for window size)
  • Return list of Vec3D coordinates
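A minimal sketch of this helper, assuming each annotation exposes either a single point or start/end line endpoints (the actual annotation API depends on the backend question below):

from zetta_utils.geometry import Vec3D

def extract_synapse_points(annotations, bbox) -> list[Vec3D]:
    points = []
    for ann in annotations:
        if hasattr(ann, "start"):                           # line annotation
            pt = (Vec3D(*ann.start) + Vec3D(*ann.end)) / 2  # midpoint
        else:                                               # point annotation
            pt = Vec3D(*ann.point)
        if bbox.contains(pt):  # assumed containment check on BBox3D
            points.append(pt)
    return points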

Question:

  • Which precomputed annotation backend should we use? (There are several now, including, as of recently, CloudVolume itself.)

4. Data Caching Strategy

Option A: Pre-read into memory

  • Read inflated bbox data once at dataset creation
  • Store as numpy arrays in indexer
  • Fast but memory-intensive

Option B: Layer caching

  • Use existing layer caching mechanisms
  • Let layer backend handle caching
  • More flexible but may have overhead

Claude's Recommendation: Start with Option A for simplicity; optimize later if needed.
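A minimal sketch of Option A (the names and data layout are illustrative; the cached array could live in the indexer or in a thin wrapper like this):

import numpy as np

class CachedVolume:
    """Pre-reads the inflated bbox once; windows become cheap array slices."""

    def __init__(self, data: np.ndarray, offset: tuple[int, int, int]):
        self.data = data      # (C, X, Y, Z) array, read once from the layer
        self.offset = offset  # global voxel coords of the array's corner

    def window(self, start, size):
        # Shift global voxel coords into the cached array, then slice.
        s = [a - b for a, b in zip(start, self.offset)]
        return self.data[
            ..., s[0]:s[0] + size[0], s[1]:s[1] + size[1], s[2]:s[2] + size[2]
        ]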

5. New TwoChanSupervisedRegime

File: zetta_utils/internal/regimes/two_chan_supervised.py (new)

Purpose: PyTorch Lightning module for 2-output-channel model training

Key differences from BinarySupervisedRegime:

  • Model outputs 2 channels (pre and post in our current case)
  • Loss computed per-channel, then combined (sum or weighted average)
  • Target keys: "target_label0" and "target_label1" (or similar)
  • Class balancing per channel
  • Logging/visualization for both channels
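The per-channel loss could look like this sketch (the target keys follow the proposal above; binary cross-entropy and the weighting scheme are illustrative choices, not a settled design):

import torch
import torch.nn.functional as F

def two_channel_loss(pred: torch.Tensor, batch: dict, weights=(1.0, 1.0)):
    # pred: (B, 2, X, Y, Z); channel 0 = presynaptic, channel 1 = postsynaptic
    total = 0.0
    for ch, key in enumerate(("target_label0", "target_label1")):
        ch_loss = F.binary_cross_entropy_with_logits(
            pred[:, ch : ch + 1], batch[key].float()
        )
        total = total + weights[ch] * ch_loss  # per-channel terms can also be logged
    return total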

Reuses from BinarySupervisedRegime:

  • Loss cropping/padding
  • Mask handling
  • Min non-zero fraction filtering
  • Training/validation structure
  • Optimizer configuration

6. Example Training Spec

File: specs/portal/train_assignment_3D.cue (new example)

"@type": "lightning_train"

regime: {
    "@type": "TwoChanSupervisedRegime"
    model: {
        "@type": "UNet3D"
        in_channels: 1
        out_channels: 2  // Pre and post channels
        // ... other model config
    }
    lr: 5e-5
    loss_crop_pad: [4, 4, 1]
}

train_dataloader: {
    "@type": "TorchDataLoader"
    batch_size: 4
    num_workers: 8
    dataset: {
        "@type": "build_collection_dataset"
        collection_name: "synapse_training"
        resolution: [4, 4, 40]
        chunk_size: [64, 64, 64]
        chunk_stride: [64, 64, 64]  // Not used for point sampling
        layer_rename_map: {
            "img": "img"
            "label0": "target_label0"   // Do we really need this?
            "label1": "target_label1"
        }
        tags: ["train"]
        points_layer_name: "points"  // New parameter
    }
}

val_dataloader: {
    // Similar but with tags: ["val"]
}
