diff --git a/antspynet/utilities/deep_atropos.py b/antspynet/utilities/deep_atropos.py
index cb57024..b681b4d 100644
--- a/antspynet/utilities/deep_atropos.py
+++ b/antspynet/utilities/deep_atropos.py
@@ -5,20 +5,19 @@ def deep_atropos(t1,
                  do_preprocessing=True,
                  use_spatial_priors=1,
                  verbose=False):
-
     """
     Six-tissue segmentation.
 
     Perform Atropos-style six tissue segmentation using deep learning.
 
     The labeling is as follows:
-    Label 0 :  background
-    Label 1 :  CSF
-    Label 2 :  gray matter
-    Label 3 :  white matter
-    Label 4 :  deep gray matter
-    Label 5 :  brain stem
-    Label 6 :  cerebellum
+    Label 0 : background
+    Label 1 : CSF
+    Label 2 : gray matter
+    Label 3 : white matter
+    Label 4 : deep gray matter
+    Label 5 : brain stem
+    Label 6 : cerebellum
 
     Preprocessing on the training data consisted of:
        * n4 bias correction,
@@ -31,8 +30,8 @@ def deep_atropos(t1,
     Arguments
     ---------
-    t1 : ANTsImage
-        raw or preprocessed 3-D T1-weighted brain image.
+    t1 : ANTsImage or list
+        raw or preprocessed 3-D T1-weighted brain image or a list of modalities.
 
     do_preprocessing : boolean
         See description above.
 
@@ -46,37 +45,38 @@ def deep_atropos(t1,
     Returns
     -------
-    List consisting of the segmentation image and probability images for
-    each label.
-
-    Example
-    -------
-    >>> image = ants.image_read("t1.nii.gz")
-    >>> flash = deep_atropos(image)
+    dict
+        Dictionary consisting of the segmentation image and probability images for each label.
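+
+    Example
+    -------
+    >>> # Usage sketch; "t1.nii.gz" is a placeholder path.
+    >>> image = ants.image_read("t1.nii.gz")
+    >>> atropos = deep_atropos(image)
+    >>> segmentation = atropos['segmentation_image']
+    >>> # A list input selects the multi-modal networks; missing modalities
+    >>> # are passed as None, e.g., deep_atropos([image, None, None]).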
 
     """
-
+    # Import internal modules (needed in both branches).
     from ..architectures import create_unet_model_3d
-    from ..utilities import get_pretrained_network
-    from ..utilities import get_antsxnet_data
-    from ..utilities import preprocess_brain_image
-    from ..utilities import extract_image_patches
-    from ..utilities import reconstruct_image_from_patches
-    from ..utilities import brain_extraction
+    from ..utilities import (get_pretrained_network, get_antsxnet_data,
+                             preprocess_brain_image, extract_image_patches,
+                             reconstruct_image_from_patches, brain_extraction)
+    # Import TensorFlow for the tf.function decorator.
+    import tensorflow as tf
+
+    classes = ("background", "csf", "gray matter", "white matter",
+               "deep gray matter", "brain stem", "cerebellum")
+
+    # =========================
+    # Single-modality branch
+    # =========================
     if not isinstance(t1, list):
 
         if t1.dimension != 3:
             raise ValueError("Image dimension must be 3.")
 
-        ################################
-        #
-        # Preprocess images
-        #
-        ################################
-
+        # Preprocess the T1 image if requested.
         t1_preprocessed = t1
         if do_preprocessing:
-            t1_preprocessing = preprocess_brain_image(t1,
+            if verbose:
+                print("Preprocessing T1 image ...")
+            t1_preproc_dict = preprocess_brain_image(
+                t1,
                 truncate_intensity=(0.01, 0.99),
                 brain_extraction_modality="t1",
                 template="croppedMni152",
@@ -84,22 +84,18 @@ def deep_atropos(t1,
                 do_bias_correction=True,
                 do_denoising=True,
                 verbose=verbose)
-            t1_preprocessed = t1_preprocessing["preprocessed_image"] * t1_preprocessing['brain_mask']
-
-        ################################
-        #
-        # Build model and load weights
-        #
-        ################################
+            # Multiply by the brain mask.
+            t1_preprocessed = t1_preproc_dict["preprocessed_image"] * t1_preproc_dict["brain_mask"]
 
+        # Set up patch parameters.
         patch_size = (112, 112, 112)
-        stride_length = (t1_preprocessed.shape[0] - patch_size[0],
-            t1_preprocessed.shape[1] - patch_size[1],
-            t1_preprocessed.shape[2] - patch_size[2])
-
-        classes = ("background", "csf", "gray matter", "white matter",
-                   "deep gray matter", "brain stem", "cerebellum")
+        stride_length = (
+            t1_preprocessed.shape[0] - patch_size[0],
+            t1_preprocessed.shape[1] - patch_size[1],
+            t1_preprocessed.shape[2] - patch_size[2]
+        )
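+        # With stride_length = shape - patch_size and "all" patches, the image is
+        # decomposed into its eight octant patches (two offsets per axis); this
+        # assumes each dimension of the preprocessed image exceeds 112 voxels.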
 
+        # Set up spatial priors if needed.
         mni_priors = None
         channel_size = 1
         if use_spatial_priors != 0:
@@ -108,91 +104,120 @@ def deep_atropos(t1,
                 mni_priors[i] = ants.copy_image_info(t1_preprocessed, mni_priors[i])
             channel_size = 2
 
-        unet_model = create_unet_model_3d((*patch_size, channel_size),
-            number_of_outputs=len(classes), mode = "classification",
-            number_of_layers=4, number_of_filters_at_base_layer=16, dropout_rate=0.0,
-            convolution_kernel_size=(3, 3, 3), deconvolution_kernel_size=(2, 2, 2),
-            weight_decay=1e-5, additional_options=("attentionGating"))
+        # Build the U-Net model.
+        unet_model = create_unet_model_3d(
+            (*patch_size, channel_size),
+            number_of_outputs=len(classes),
+            mode="classification",
+            number_of_layers=4,
+            number_of_filters_at_base_layer=16,
+            dropout_rate=0.0,
+            convolution_kernel_size=(3, 3, 3),
+            deconvolution_kernel_size=(2, 2, 2),
+            weight_decay=1e-5,
+            additional_options=("attentionGating",)
+        )
 
         if verbose:
-            print("DeepAtropos:  retrieving model weights.")
+            print("DeepAtropos: retrieving model weights.")
 
-        weights_file_name = ''
+        # Load the appropriate pre-trained weights.
         if use_spatial_priors == 0:
             weights_file_name = get_pretrained_network("sixTissueOctantBrainSegmentation")
         elif use_spatial_priors == 1:
             weights_file_name = get_pretrained_network("sixTissueOctantBrainSegmentationWithPriors1")
         else:
-            raise ValueError("use_spatial_priors must be a 0 or 1")
+            raise ValueError("use_spatial_priors must be 0 or 1")
         unet_model.load_weights(weights_file_name)
 
-        ################################
-        #
-        # Do prediction and normalize to native space
-        #
-        ################################
+        # --- Define a prediction function to avoid repeated retracing ---
+        @tf.function(reduce_retracing=True)
+        def model_predict(x):
+            return unet_model(x, training=False)
 
+        # Normalize image and extract patches.
         if verbose:
             print("Prediction.")
-
         t1_preprocessed = (t1_preprocessed - t1_preprocessed.mean()) / t1_preprocessed.std()
-        image_patches = extract_image_patches(t1_preprocessed, patch_size=patch_size,
-            max_number_of_patches="all", stride_length=stride_length,
-            return_as_array=True)
+        image_patches = extract_image_patches(
+            t1_preprocessed,
+            patch_size=patch_size,
+            max_number_of_patches="all",
+            stride_length=stride_length,
+            return_as_array=True
+        )
 
+        # batchX shape: (n_patches, patch_dim0, patch_dim1, patch_dim2, channel_size)
         batchX = np.zeros((*image_patches.shape, channel_size))
-        batchX[:,:,:,:,0] = image_patches
+        batchX[..., 0] = image_patches
         if channel_size > 1:
-            prior_patches = extract_image_patches(mni_priors[6], patch_size=patch_size,
-                max_number_of_patches="all", stride_length=stride_length,
-                return_as_array=True)
-            batchX[:,:,:,:,1] = prior_patches
-
-        predicted_data = unet_model.predict(batchX, verbose=verbose)
-
-        probability_images = list()
+            prior_patches = extract_image_patches(
+                mni_priors[6],
+                patch_size=patch_size,
+                max_number_of_patches="all",
+                stride_length=stride_length,
+                return_as_array=True
+            )
+            batchX[..., 1] = prior_patches
+
+        # Use the model_predict function (with fixed input shape) to compute predictions.
+        predicted_tensor = model_predict(tf.convert_to_tensor(batchX))
+        predicted_data = predicted_tensor.numpy()
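+        # predicted_data has shape (n_patches, *patch_size, len(classes)); the
+        # classification-mode U-Net ends in a softmax, so these are per-voxel
+        # class probabilities.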
+
+        # Reconstruct probability images and then compute segmentation.
+        probability_images = []
         for i in range(len(classes)):
             if verbose:
                 print("Reconstructing image", classes[i])
-            reconstructed_image = reconstruct_image_from_patches(predicted_data[:,:,:,:,i],
-                domain_image=t1_preprocessed, stride_length=stride_length)
-
+            reconstructed_image = reconstruct_image_from_patches(
+                predicted_data[..., i],
+                domain_image=t1_preprocessed,
+                stride_length=stride_length
+            )
             if do_preprocessing:
-                probability_images.append(ants.apply_transforms(fixed=t1,
-                    moving=reconstructed_image,
-                    transformlist=t1_preprocessing['template_transforms']['invtransforms'],
-                    whichtoinvert=[True], interpolator="linear", verbose=verbose))
-            else:
-                probability_images.append(reconstructed_image)
+                # Map the probability image back to native space.
+                reconstructed_image = ants.apply_transforms(
+                    fixed=t1,
+                    moving=reconstructed_image,
+                    transformlist=t1_preproc_dict["template_transforms"]["invtransforms"],
+                    whichtoinvert=[True],
+                    interpolator="linear",
+                    verbose=verbose
+                )
+            probability_images.append(reconstructed_image)
 
         image_matrix = ants.image_list_to_matrix(probability_images, t1 * 0 + 1)
         segmentation_matrix = np.argmax(image_matrix, axis=0)
         segmentation_image = ants.matrix_to_images(
-            np.expand_dims(segmentation_matrix, axis=0), t1 * 0 + 1)[0]
-
-        return_dict = {'segmentation_image' : segmentation_image,
-                       'probability_images' : probability_images}
-        return(return_dict)
+            np.expand_dims(segmentation_matrix, axis=0),
+            t1 * 0 + 1
+        )[0]
+
+        return {'segmentation_image': segmentation_image,
+                'probability_images': probability_images}
 
+    # ====================================
+    # Multi-modality branch (list input)
+    # ====================================
     else:
-
         if len(t1) != 3:
-            raise ValueError("Length of input list must be 3. Input images are (in order): [T1, T2, FA]." +
-                "If a particular modality or modalities is not available, use None as a placeholder.")
-
+            raise ValueError("Length of input list must be 3. Input images are (in order): [T1, T2, FA]. "
+                             "If a particular modality or modalities is not available, use None as a placeholder.")
+
         if t1[0] is None:
             raise ValueError("T1 modality must be specified.")
-
-        which_network = ""
-        input_images = list()
-        input_images.append(t1[0])
+
+        input_images = [t1[0]]
+        # Determine which network to use based on available modalities.
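+        # e.g., [t1, t2, fa] -> "t1_t2_fa"; [t1, None, fa] -> "t1_fa";
+        # and [t1, None, None] falls through to "t1".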
         if t1[1] is not None and t1[2] is not None:
             which_network = "t1_t2_fa"
-            input_images.append(t1[1])
-            input_images.append(t1[2])
-        elif t1[1] is not None:
+            input_images.extend([t1[1], t1[2]])
+        elif t1[1] is not None and t1[2] is None:
             which_network = "t1_t2"
             input_images.append(t1[1])
-        elif t1[2] is not None:
+        elif t1[1] is None and t1[2] is not None:
             which_network = "t1_fa"
             input_images.append(t1[2])
         else:
@@ -201,145 +226,165 @@ def deep_atropos(t1,
         if verbose:
             print("Prediction using", which_network)
 
-        ################################
-        #
-        # Preprocess images
-        #
-        ################################
-
+        # Read and prepare HCP template images and masks.
         hcp_t1_template = ants.image_read(get_antsxnet_data("hcpinterT1Template"))
         hcp_template_brain_mask = ants.image_read(get_antsxnet_data("hcpinterTemplateBrainMask"))
         hcp_template_brain_segmentation = ants.image_read(get_antsxnet_data("hcpinterTemplateBrainSegmentation"))
-
         hcp_t1_template = hcp_t1_template * hcp_template_brain_mask
 
         reg = None
         t1_mask = None
-        preprocessed_images = list()
-        for i in range(len(input_images)):
-            n4 = ants.n4_bias_field_correction(input_images[i], mask=input_images[i]*0+1,
-                convergence={'iters': [50, 50, 50, 50], 'tol': 0.0},
-                rescale_intensities=True,
-                verbose=verbose)
+        preprocessed_images = []
+
+        # Preprocess each input modality.
+        for i, img in enumerate(input_images):
+            n4 = ants.n4_bias_field_correction(
+                img,
+                mask=img * 0 + 1,
+                convergence={'iters': [50, 50, 50, 50], 'tol': 0.0},
+                rescale_intensities=True,
+                verbose=verbose
+            )
             if i == 0:
-                t1_bext = brain_extraction(input_images[0], modality="t1threetissue", verbose=verbose)
-                t1_mask = t1_bext['probability_images'][1]
+                t1_bext = brain_extraction(img, modality="t1threetissue", verbose=verbose)
+                t1_mask = t1_bext["probability_images"][1]
                 n4 = n4 * t1_mask
-                reg = ants.registration(hcp_t1_template, n4,
-                    type_of_transform="antsRegistrationSyNQuick[a]",
-                    verbose=verbose)
-                preprocessed_images.append(reg['warpedmovout'])
+                reg = ants.registration(
+                    hcp_t1_template, n4,
+                    type_of_transform="antsRegistrationSyNQuick[a]",
+                    verbose=verbose
+                )
+                preprocessed_images.append(reg["warpedmovout"])
             else:
                 n4 = n4 * t1_mask
-                n4 = ants.apply_transforms(hcp_t1_template, n4,
-                    transformlist=reg['fwdtransforms'],
-                    verbose=verbose)
+                n4 = ants.apply_transforms(
+                    hcp_t1_template, n4,
+                    transformlist=reg["fwdtransforms"],
+                    verbose=verbose
+                )
                 preprocessed_images.append(n4)
-            preprocessed_images[i] = ants.iMath_normalize(preprocessed_images[i])
-
-        ################################
-        #
-        # Build model and load weights
-        #
-        ################################
-
+            # Normalize intensities to [0, 1].
+            preprocessed_images[i] = ants.iMath_normalize(preprocessed_images[i])
+
+        # Set patch parameters based on the HCP template.
         patch_size = (192, 224, 192)
-        stride_length = (hcp_t1_template.shape[0] - patch_size[0],
-            hcp_t1_template.shape[1] - patch_size[1],
-            hcp_t1_template.shape[2] - patch_size[2])
-
-        hcp_template_priors = list()
+        stride_length = (
+            hcp_t1_template.shape[0] - patch_size[0],
+            hcp_t1_template.shape[1] - patch_size[1],
+            hcp_t1_template.shape[2] - patch_size[2]
+        )
+
+        # Prepare the spatial priors.
+        hcp_template_priors = []
         for i in range(6):
-            prior = ants.threshold_image(hcp_template_brain_segmentation, i+1, i+1, 1, 0)
+            prior = ants.threshold_image(hcp_template_brain_segmentation, i + 1, i + 1, 1, 0)
             prior_smooth = ants.smooth_image(prior, 1.0)
             hcp_template_priors.append(prior_smooth)
 
-        classes = ("background", "csf", "gray matter", "white matter",
-                   "deep gray matter", "brain stem", "cerebellum")
         number_of_classification_labels = len(classes)
-        channel_size = len(input_images) + len(hcp_template_priors)
-
-        unet_model = create_unet_model_3d((*patch_size, channel_size),
-            number_of_outputs=number_of_classification_labels, mode="classification",
-            number_of_filters=(16, 32, 64, 128), dropout_rate=0.0,
-            convolution_kernel_size=(3, 3, 3), deconvolution_kernel_size=(2, 2, 2),
-            weight_decay=0.0)
+        # The input channels comprise the available modalities plus each of the 6 priors.
+        channel_size = len(preprocessed_images) + len(hcp_template_priors)
+
+        # Build the U-Net model.
+        unet_model = create_unet_model_3d(
+            (*patch_size, channel_size),
+            number_of_outputs=number_of_classification_labels,
+            mode="classification",
+            number_of_filters=(16, 32, 64, 128),
+            dropout_rate=0.0,
+            convolution_kernel_size=(3, 3, 3),
+            deconvolution_kernel_size=(2, 2, 2),
+            weight_decay=0.0
+        )
 
         if verbose:
-            print("DeepAtropos:  retrieving model weights.")
+            print("DeepAtropos: retrieving model weights.")
 
-        weights_file_name = ""
         if which_network == "t1":
             weights_file_name = get_pretrained_network("DeepAtroposHcpT1Weights")
         elif which_network == "t1_t2":
             weights_file_name = get_pretrained_network("DeepAtroposHcpT1T2Weights")
         elif which_network == "t1_fa":
             weights_file_name = get_pretrained_network("DeepAtroposHcpT1FAWeights")
        elif which_network == "t1_t2_fa":
             weights_file_name = get_pretrained_network("DeepAtroposHcpT1T2FAWeights")
-
         unet_model.load_weights(weights_file_name)
 
-        ################################
-        #
-        # Do prediction and normalize to native space
-        #
-        ################################
+        # --- Define a prediction function to avoid retracing ---
+        @tf.function(reduce_retracing=True)
+        def model_predict(x):
+            return unet_model(x, training=False)
 
         if verbose:
             print("Prediction.")
 
-        predicted_data = np.zeros((8, *patch_size, number_of_classification_labels))
-
+        # === Pre-extract patches for each modality and prior ===
+        preproc_patches = []
+        for img in preprocessed_images:
+            patches = extract_image_patches(
+                img,
+                patch_size=patch_size,
+                max_number_of_patches="all",
+                stride_length=stride_length,
+                return_as_array=True
+            )
+            preproc_patches.append(patches)
+
+        prior_patches_list = []
+        for prior in hcp_template_priors:
+            patches = extract_image_patches(
+                prior,
+                patch_size=patch_size,
+                max_number_of_patches="all",
+                stride_length=stride_length,
+                return_as_array=True
+            )
+            prior_patches_list.append(patches)
+
+        num_patches = preproc_patches[0].shape[0]
+        predicted_data = np.zeros((num_patches, *patch_size, number_of_classification_labels))
         batchX = np.zeros((1, *patch_size, channel_size))
 
-        for h in range(8):
+        # Loop over each patch index.
+        for h in range(num_patches):
             index = 0
-            for i in range(len(preprocessed_images)):
-                patches = extract_image_patches(preprocessed_images[i],
-                    patch_size=patch_size,
-                    max_number_of_patches="all",
-                    stride_length=stride_length,
-                    return_as_array=True)
-                batchX[0,:,:,:,index] = patches[h,:,:,:]
-                index = index + 1
-
-            for i in range(len(hcp_template_priors)):
-                patches = extract_image_patches(hcp_template_priors[i],
-                    patch_size=patch_size,
-                    max_number_of_patches="all",
-                    stride_length=stride_length,
-                    return_as_array=True)
-                batchX[0,:,:,:,index] = patches[h,:,:,:]
-                index = index + 1
-
-            predicted_data[h,:,:,:,:] = unet_model.predict(batchX, verbose=verbose)
-
-        probability_images = list()
+            for patches in preproc_patches:
+                batchX[0, :, :, :, index] = patches[h, :, :, :]
+                index += 1
+            for patches in prior_patches_list:
+                batchX[0, :, :, :, index] = patches[h, :, :, :]
+                index += 1
+            # Convert the numpy array to a tensor and predict for the current patch.
+            predicted_tensor = model_predict(tf.convert_to_tensor(batchX))
+            predicted_data[h, :, :, :, :] = predicted_tensor.numpy()
+
+        # Reconstruct probability images from the patches.
+        probability_images = []
         for i in range(len(classes)):
             if verbose:
                 print("Reconstructing image", classes[i])
-            reconstructed_image = reconstruct_image_from_patches(predicted_data[:,:,:,:,i],
-                domain_image=hcp_t1_template, stride_length=stride_length)
-
+            reconstructed_image = reconstruct_image_from_patches(
+                predicted_data[..., i],
+                domain_image=hcp_t1_template,
+                stride_length=stride_length
+            )
             if do_preprocessing:
-                probability_images.append(ants.apply_transforms(fixed=input_images[0],
-                    moving=reconstructed_image,
-                    transformlist=reg['invtransforms'],
-                    whichtoinvert=[True], interpolator="linear", verbose=verbose))
-            else:
-                probability_images.append(reconstructed_image)
+                # Map the probability image back to native space.
+                reconstructed_image = ants.apply_transforms(
+                    fixed=input_images[0],
+                    moving=reconstructed_image,
+                    transformlist=reg["invtransforms"],
+                    whichtoinvert=[True],
+                    interpolator="linear",
+                    verbose=verbose
+                )
+            probability_images.append(reconstructed_image)
 
         image_matrix = ants.image_list_to_matrix(probability_images, input_images[0] * 0 + 1)
         segmentation_matrix = np.argmax(image_matrix, axis=0)
         segmentation_image = ants.matrix_to_images(
-            np.expand_dims(segmentation_matrix, axis=0), input_images[0] * 0 + 1)[0]
-
-        return_dict = {'segmentation_image' : segmentation_image,
-                       'probability_images' : probability_images}
-        return(return_dict)
-
+            np.expand_dims(segmentation_matrix, axis=0),
+            input_images[0] * 0 + 1
+        )[0]
+
+        return {'segmentation_image': segmentation_image,
+                'probability_images': probability_images}