Description
Problem
Assignment model training differs from typical binary classifier training in two ways:
- The model applies to a small 3D patch around specific points (given in the form of point or line annotations).
- The model has two output channels (presynaptic and postsynaptic cell mask) instead of just one.
The second point is easy to handle, but the first one requires some new code.
Proposal
Add support for this sort of "point-windowed" training, allowing the user to specify a "points" layer containing point or line annotations to window around, and then use a new sampler that efficiently samples those little windows around points within each bounding box.
Data Organization
Layer Group defines all data sources:
"img"- EM image data"label0"- Presynaptic cell labels (ground truth)"label1"- Postsynaptic cell labels (ground truth)"points"- Precomputed annotation layer containing point or line annotations marking synapse locations
Collection contains bounding box annotations:
- Each bbox annotation defines a volume of interest
- Tagged with "train" or "val" for dataset split
- Bboxes should be small enough to fit in memory (plus padding)
Training Workflow
For each bbox annotation in the collection, the code will:
1. Load volume into memory
   - Read the entire bbox volume from all layers (img, label0, label1)
   - Inflate by half the window size on each side to allow edge samples (see the sketch after this list)
   - Cache this data in memory
2. Read synapse locations
   - Read all point/line annotations from the "points" layer that fall within the bbox
   - For line annotations, compute the midpoint
   - Filter to only points inside the (non-inflated) bbox
3. Create point-based sampler
   - Create a dataset that samples centered windows around each point
   - Windows are extracted from the cached in-memory volume
   - Many samples per bbox, all from cached data
4. Combine datasets
   - One dataset per bbox (as the current build_collection_dataset does)
   - Join vertically to create the final training dataset
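To make the inflate-then-window idea concrete, here is a minimal numpy-only sketch. The shapes, helper name, and the convention that the padding equals half the window size are illustrative assumptions, not the eventual zetta_utils implementation.

import numpy as np

def extract_centered_window(cached_vol, point_in_bbox, window_size, pad):
    # Illustrative only: slice a centered window out of a cached, inflated volume.
    # cached_vol: bbox volume inflated by `pad` voxels on each side (step 1)
    # point_in_bbox: synapse location in non-inflated bbox voxel coordinates (step 2)
    # window_size: e.g. (64, 64, 64); pad is assumed to equal window_size // 2
    center = np.asarray(point_in_bbox) + np.asarray(pad)  # shift into inflated-volume coords
    start = center - np.asarray(window_size) // 2
    end = start + np.asarray(window_size)
    return cached_vol[tuple(slice(s, e) for s, e in zip(start, end))]

# A point on the bbox edge still yields a full window thanks to the inflation.
vol = np.zeros((128, 128, 128), dtype=np.uint8)  # 64^3 bbox inflated by 32 on each side
window = extract_centered_window(vol, (0, 0, 0), (64, 64, 64), (32, 32, 32))
assert window.shape == (64, 64, 64)

This is exactly why step 1 inflates by half the window size: points on the bbox boundary still produce full-size samples without any out-of-bounds handling.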
Benefits
- Each volume loaded once, many point samples extracted from it
- Leverages existing collection bbox infrastructure
- No need to modify the db_annotations schema
- Precomputed points layer treated like any other data layer
- Efficient: amortizes I/O cost across many samples per volume
- Natural train/val split via bbox tags
Implementation Plan
1. Extend build_collection_dataset
File: zetta_utils/training/datasets/collection_dataset.py
Changes:
- Detect if the layer group contains a "points" layer (or a configurable name; see the sketch below)
- For each bbox annotation:
  - Read points from the annotation layer within bbox bounds
  - Inflate bbox by chunk_size // 2 for padding
  - Pre-load inflated bbox data into memory (or create a cached layer view)
  - Create a custom indexer that:
    - Holds the list of point locations
    - Maps integer index → centered window around that point
    - Reads from cached data
Questions to resolve:
- Should cached data be held by the indexer or the layer?
- Should we support configurable layer names (not hardcoded "points")?
2. Create CachedPointIndexer
File: zetta_utils/training/datasets/sample_indexers/cached_point_indexer.py (new)
Purpose: Indexer that samples windows around points from cached volume data
Interface:
# Import paths assumed from the existing zetta_utils layout; adjust if needed.
from typing import Sequence
from zetta_utils import builder
from zetta_utils.geometry import BBox3D, Vec3D
from zetta_utils.layer.volumetric import VolumetricIndex
from zetta_utils.training.datasets.sample_indexers import SampleIndexer

@builder.register("CachedPointIndexer")
class CachedPointIndexer(SampleIndexer):
    points: list[Vec3D]  # Synapse locations
    chunk_size: Sequence[int]  # Window size
    resolution: Sequence[float]
    bbox: BBox3D  # Bbox bounds for validation

    def __len__(self) -> int:
        return len(self.points)

    def __call__(self, idx: int) -> VolumetricIndex:
        # Return VolumetricIndex for window centered at points[idx] (sketched below)
        ...

Key features:
- Takes list of point locations (already filtered to bbox)
- Computes centered window for each point
- Returns VolumetricIndex like other indexers
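A possible shape for that __call__, assuming BBox3D.from_coords and keyword construction of VolumetricIndex behave the way they do elsewhere in zetta_utils; the exact arithmetic helpers on Vec3D may need adjusting to the current API.

    def __call__(self, idx: int) -> VolumetricIndex:
        # Center a chunk_size window on points[idx]; coordinates are voxels at self.resolution.
        point = self.points[idx]
        start = (point - Vec3D(*self.chunk_size) // 2).int()
        end = start + Vec3D(*self.chunk_size)
        window_bbox = BBox3D.from_coords(start_coord=start, end_coord=end, resolution=self.resolution)
        return VolumetricIndex(resolution=Vec3D(*self.resolution), bbox=window_bbox)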
3. Point Extraction Logic
Location: Helper function in collection_dataset.py or separate utility
Purpose: Read and filter points from annotation layer
Functionality:
- Read precomputed annotations layer at bbox bounds
- Extract point coordinates (handle both point and line types)
- For line annotations, compute the midpoint: (pointA + pointB) / 2 (or maybe add an extra parameter to control this; see the sketch below)
- Filter to points within bbox (accounting for window size)
- Return list of Vec3D coordinates
Question:
- Which precomputed annotation backend to use? (There are several now, including recent versions of CloudVolume itself.)
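Since the backend question is still open, here is a backend-agnostic sketch of the midpoint-and-filter step; the input format (each annotation as a single coordinate or as a pair of line endpoints) is an assumption for illustration only.

import numpy as np

def annotations_to_points(annotations, bbox_start, bbox_end):
    # Assumed input: each entry is either one (x, y, z) coordinate (point annotation)
    # or a pair of endpoints ((x, y, z), (x, y, z)) (line annotation).
    points = []
    for ann in annotations:
        coords = np.asarray(ann, dtype=float)
        if coords.ndim == 2:  # line annotation: reduce to its midpoint
            coords = coords.mean(axis=0)
        # Keep only points inside the (non-inflated) bbox.
        if np.all(coords >= bbox_start) and np.all(coords < bbox_end):
            points.append(coords)
    return points

The window-size margin mentioned above is handled by the inflated read rather than by this filter, but an extra margin argument could be added here if windows must stay strictly inside the original bbox.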
4. Data Caching Strategy
Option A: Pre-read into memory
- Read inflated bbox data once at dataset creation
- Store as numpy arrays in indexer
- Fast but memory-intensive
Option B: Layer caching
- Use existing layer caching mechanisms
- Let layer backend handle caching
- More flexible but may have overhead
Claude's Recommendation: Start with Option A for simplicity, optimize later if needed
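For a rough sense of scale on Option A (assuming, purely as an example, a 512 × 512 × 128 voxel bbox stored as uint8 across the img, label0, and label1 layers): 512 · 512 · 128 · 3 bytes ≈ 100 MB before padding, so a modest number of such bboxes per dataloader worker fits in memory, while much larger bboxes, float data, or many workers would push toward Option B.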
5. New TwoChanSupervisedRegime
File: zetta_utils/internal/regimes/two_chan_supervised.py (new)
Purpose: PyTorch Lightning module for 2-output-channel model training
Key differences from BinarySupervisedRegime:
- Model outputs 2 channels (pre and post in our current case)
- Loss computed per-channel, then combined (sum or weighted average; see the sketch below)
- Target keys: "target_label0" and "target_label1" (or similar)
- Class balancing per channel
- Logging/visualization for both channels
Reuses from BinarySupervisedRegime:
- Loss cropping/padding
- Mask handling
- Min non-zero fraction filtering
- Training/validation structure
- Optimizer configuration
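A minimal sketch of the per-channel loss combination, which is the part that actually differs from BinarySupervisedRegime; the BCE loss, channel order (0 = pre, 1 = post), and weighting scheme are placeholders, and the cropping, masking, and class balancing reused from the existing regime would wrap around this.

import torch
import torch.nn.functional as F

def two_channel_loss(logits: torch.Tensor, target0: torch.Tensor, target1: torch.Tensor,
                     weights: tuple[float, float] = (1.0, 1.0)) -> torch.Tensor:
    # logits: (batch, 2, *spatial); channel 0 = presynaptic, channel 1 = postsynaptic.
    loss_pre = F.binary_cross_entropy_with_logits(logits[:, 0:1], target0)
    loss_post = F.binary_cross_entropy_with_logits(logits[:, 1:2], target1)
    # Combine per-channel losses (sum or weighted average, as described above).
    return weights[0] * loss_pre + weights[1] * loss_post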
6. Example Training Spec
File: specs/portal/train_assignment_3D.cue (new example)
"@type": "lightning_train"
regime: {
"@type": "TwoChanSupervisedRegime"
model: {
"@type": "UNet3D"
in_channels: 1
out_channels: 2 // Pre and post channels
// ... other model config
}
lr: 5e-5
loss_crop_pad: [4, 4, 1]
}
train_dataloader: {
"@type": "TorchDataLoader"
batch_size: 4
num_workers: 8
dataset: {
"@type": "build_collection_dataset"
collection_name: "synapse_training"
resolution: [4, 4, 40]
chunk_size: [64, 64, 64]
chunk_stride: [64, 64, 64] // Not used for point sampling
layer_rename_map: {
"img": "img"
"label0": "target_label0" // Do we really need this?
"label1": "target_label1"
}
tags: ["train"]
points_layer_name: "points" // New parameter
}
}
val_dataloader: {
// Similar but with tags: ["val"]
}