TLDR: We automatically segmented brain data using deep learning. See our demo.
How does the brain work? Can we understand it as an electric circuit by analyzing how its neurons are connected? The field of connectomics in neurobiology searches for answers to these questions by imaging the brain at high resolution, layer by layer. The result is a massive 3D image that shows the connections we care about. But for a meaningful analysis, we need to interpret these images: We need to know which voxels (pixels in 3D images) belong together. In other words, we need to segment the 3D image.
For example, consider this extract of the CREMI dataset. In the raw data, the structure is visible to the human eye. Segmentation makes this structure explicit by assigning the same ID to all voxels that belong to the same object, shown here as a color-coded segmentation layer:
The problem is that segmenting these datasets manually is tedious. Even worse, such datasets are orders of magnitude too large for humans to annotate in any reasonable amount of time. Fortunately, with the advent of artificial neural networks in image processing, there has recently been a lot of progress in automatic image segmentation.
Modeling the Segmentation Problem with Affinity Graphs
Modeling the segmentation task is not straightforward: In contrast to image classification, where we make one prediction for the entire image, segmentation requires grouping voxels together into segments. One approach is to consider the affinity graph of voxels: It has a node for each voxel and an edge between each pair of adjacent voxels. Each edge is labeled 1 if the two voxels belong to the same segment and 0 if they don't. The following 2D example shows the affinity graph for two segments:
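To make the edge labels concrete, here is a minimal NumPy sketch (our own illustration, not code from our pipeline) that derives them from a 2D ground-truth label image. The function name `affinities_2d` and the convention of storing each edge at the pixel it starts from are assumptions made for this example.

```python
import numpy as np


def affinities_2d(labels: np.ndarray) -> np.ndarray:
    """Return a (2, H, W) array of 0/1 edge labels for a 2D label image.

    Channel 0 holds the edge from pixel (y, x) to (y, x + 1),
    channel 1 the edge from (y, x) to (y + 1, x).
    An edge is 1 if both pixels carry the same segment ID, else 0.
    Edges that would leave the image are set to 0.
    """
    h, w = labels.shape
    aff = np.zeros((2, h, w), dtype=np.uint8)
    aff[0, :, :-1] = labels[:, :-1] == labels[:, 1:]  # edges to the right neighbour
    aff[1, :-1, :] = labels[:-1, :] == labels[1:, :]  # edges to the lower neighbour
    return aff
```

For a label image with two segments, `np.array([[1, 1, 2], [1, 1, 2]])`, the x-edges crossing from segment 1 into segment 2 come out as 0, while all other edges inside the image are 1.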
We can convert such a graph into a valid segmentation by finding its connected components, considering only edges labeled 1. To obtain a labeled affinity graph, we therefore need three predictions per voxel: One for the edge to its next neighbor in each of the x-, y-, and z-directions. As a result, we need a neural network with three output units for each voxel it makes predictions for.
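The conversion back to a segmentation can be sketched as a breadth-first search over the affinity graph. This is a minimal illustration assuming a `(3, Z, Y, X)` affinity array with channel order x, y, z and a hypothetical threshold of 0.5 for turning predicted affinities into 0/1 edges; a production pipeline would use a much faster connected-components implementation.

```python
from collections import deque

import numpy as np


def segment_from_affinities(aff: np.ndarray, threshold: float = 0.5) -> np.ndarray:
    """Turn a (3, Z, Y, X) affinity array into a segmentation of shape (Z, Y, X).

    aff[0, z, y, x] is the edge from (z, y, x) to (z, y, x + 1),
    aff[1, z, y, x] the edge to (z, y + 1, x) and
    aff[2, z, y, x] the edge to (z + 1, y, x).
    Voxels joined by a path of edges above `threshold` get the same ID (1, 2, ...).
    """
    edges = aff > threshold
    _, nz, ny, nx = aff.shape
    seg = np.zeros((nz, ny, nx), dtype=np.int64)
    # (channel, offset) of the "next" neighbour in x, y and z
    steps = [(0, (0, 0, 1)), (1, (0, 1, 0)), (2, (1, 0, 0))]
    next_id = 0
    for start in np.ndindex(nz, ny, nx):
        if seg[start]:
            continue  # already part of an earlier component
        next_id += 1
        seg[start] = next_id
        queue = deque([start])
        while queue:
            z, y, x = queue.popleft()
            for c, (dz, dy, dx) in steps:
                # forward edge: stored at the current voxel
                n = (z + dz, y + dy, x + dx)
                if n[0] < nz and n[1] < ny and n[2] < nx and edges[c, z, y, x] and not seg[n]:
                    seg[n] = next_id
                    queue.append(n)
                # backward edge: stored at the neighbour it starts from
                p = (z - dz, y - dy, x - dx)
                if min(p) >= 0 and edges[(c,) + p] and not seg[p]:
                    seg[p] = next_id
                    queue.append(p)
    return seg
```

Note that a single wrongly predicted 1-edge between two segments is enough to merge them into one component, which is one reason why mergers are a typical failure mode of this modeling.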
Network Architecture: Fully Convolutional Networks
Image classification architectures such as GoogLeNet are unsuitable for this type of modeling because they are generally designed to compute a single output for the entire image. We would have to run such a network once for every single voxel, feeding it an input image centered on that voxel. Fully convolutional networks, on the other hand, achieve this task much more efficiently: Because these networks do not contain fully connected layers, all of their computations are translation-equivariant (shifting the input shifts the output). As a result, these networks compute predictions for many voxels at once, sharing internal computation between them.
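The efficiency argument can be checked with a toy 2D network of two 3x3 convolution layers with ReLU, written in NumPy. The kernels are random rather than trained, and all names are illustrative. `predict_voxel_by_voxel` runs the network once per output pixel on a 5x5 patch, the way one would have to use a classification network, while `two_layer_net` runs it once on the whole image.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def conv_relu(image: np.ndarray, kernel: np.ndarray) -> np.ndarray:
    """Valid 2D cross-correlation followed by ReLU."""
    windows = sliding_window_view(image, kernel.shape)
    return np.maximum(np.einsum("yxij,ij->yx", windows, kernel), 0.0)


def two_layer_net(image: np.ndarray, k1: np.ndarray, k2: np.ndarray) -> np.ndarray:
    """A tiny fully convolutional network: two conv + ReLU layers."""
    return conv_relu(conv_relu(image, k1), k2)


def predict_voxel_by_voxel(image: np.ndarray, k1: np.ndarray, k2: np.ndarray) -> np.ndarray:
    """Slow way: call the network once per output pixel on its own patch."""
    size = k1.shape[0] + k2.shape[0] - 1  # receptive field of one output pixel
    h, w = image.shape
    out = np.empty((h - size + 1, w - size + 1))
    for y in range(out.shape[0]):
        for x in range(out.shape[1]):
            patch = image[y:y + size, x:x + size]
            out[y, x] = two_layer_net(patch, k1, k2)[0, 0]
    return out
```

Both functions give the same outputs, but the per-pixel version recomputes the first-layer features of overlapping patches over and over, whereas the fully convolutional version computes each of them exactly once.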
One of the most popular fully convolutional networks is U-Net. It bases its predictions on low-level features ("Is this voxel dark?") as well as high-level features ("What does the surrounding area look like?"). Because of its strong performance on a number of 2D segmentation benchmarks, we used a 3D variant of this network as our machine learning component.
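The core idea of combining fine and coarse features can be illustrated with a deliberately tiny, one-level encoder-decoder in NumPy. This is not our network and not a real U-Net: It uses 1x1x1 convolutions instead of 3x3x3 ones, a single pooling level, and random weights, purely to show how a skip connection concatenates full-resolution features with upsampled coarse features before a final layer predicts three affinities per voxel.

```python
import numpy as np


def max_pool(x: np.ndarray) -> np.ndarray:
    """2x2x2 max pooling of a (C, Z, Y, X) array with even spatial sizes."""
    c, z, y, w = x.shape
    return x.reshape(c, z // 2, 2, y // 2, 2, w // 2, 2).max(axis=(2, 4, 6))


def upsample(x: np.ndarray) -> np.ndarray:
    """Nearest-neighbour 2x upsampling along Z, Y and X."""
    return x.repeat(2, axis=1).repeat(2, axis=2).repeat(2, axis=3)


def pointwise_conv(x: np.ndarray, weight: np.ndarray) -> np.ndarray:
    """1x1x1 convolution: a (C_out, C_in) weight applied at every voxel."""
    return np.einsum("oc,czyx->ozyx", weight, x)


def toy_unet(volume, w_enc, w_bottom, w_out):
    """One-level U-shaped network mapping (C, Z, Y, X) to (3, Z, Y, X)."""
    fine = np.maximum(pointwise_conv(volume, w_enc), 0)               # full resolution
    coarse = np.maximum(pointwise_conv(max_pool(fine), w_bottom), 0)  # half resolution
    merged = np.concatenate([fine, upsample(coarse)], axis=0)        # skip connection
    logits = pointwise_conv(merged, w_out)
    return 1.0 / (1.0 + np.exp(-logits))  # sigmoid: one affinity per edge direction
```

In a real U-Net, several such levels are stacked and each uses spatial convolutions, so the coarse path sees a much larger neighborhood than the fine path.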
Our Image Segmentation Pipeline
We implemented an image segmentation pipeline that trains the 3D U-Net on a relatively small ground-truth segmentation. The resulting model can then be used to segment the rest of the dataset.
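Applying the trained model to the rest of the dataset usually means processing the volume piece by piece, since it does not fit into memory at once. A minimal sketch, assuming a hypothetical `predict` function that wraps the trained network and returns three affinity channels for a block; the default block size is an arbitrary illustrative value.

```python
from typing import Callable, Tuple

import numpy as np


def predict_in_blocks(
    volume: np.ndarray,
    predict: Callable[[np.ndarray], np.ndarray],
    block: Tuple[int, int, int] = (32, 128, 128),
) -> np.ndarray:
    """Run `predict` over a (Z, Y, X) volume block by block.

    `predict` is a stand-in for the trained network: it maps a (z, y, x)
    block to a (3, z, y, x) array of affinities. Blocks at the volume
    border may be smaller than `block`.
    """
    out = np.zeros((3,) + volume.shape, dtype=np.float32)
    bz, by, bx = block
    for z in range(0, volume.shape[0], bz):
        for y in range(0, volume.shape[1], by):
            for x in range(0, volume.shape[2], bx):
                region = (slice(z, z + bz), slice(y, y + by), slice(x, x + bx))
                out[(slice(None),) + region] = predict(volume[region])
    return out
```

This sketch ignores an important detail: Predictions near block borders lack context, so a real pipeline would usually read each block with some overlap and keep only the center of each prediction.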
Note that preprocessing steps such as 3D image alignment can have a significant impact on performance.
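As a rough illustration of what alignment involves, the following sketch estimates a global integer translation between two consecutive slices with phase correlation. This is a simple textbook technique chosen for illustration, not necessarily what was used for our data; real serial-section alignment usually also has to correct rotations and non-rigid deformations.

```python
from typing import Tuple

import numpy as np


def estimate_shift(reference: np.ndarray, moving: np.ndarray) -> Tuple[int, int]:
    """Estimate the integer (dy, dx) by which `moving` is shifted relative to `reference`.

    Uses phase correlation: the normalized cross-power spectrum of two
    translated images is a pure phase ramp whose inverse FFT peaks at the shift.
    """
    cross = np.fft.fft2(moving) * np.conj(np.fft.fft2(reference))
    cross /= np.abs(cross) + 1e-12  # keep only the phase
    corr = np.fft.ifft2(cross).real
    dy, dx = np.unravel_index(np.argmax(corr), corr.shape)
    h, w = reference.shape
    # peaks in the upper half of the range correspond to negative shifts
    if dy > h // 2:
        dy -= h
    if dx > w // 2:
        dx -= w
    return int(dy), int(dx)
```

Shifting the moving slice back by the estimated offset, e.g. with `np.roll(moving, (-dy, -dx), axis=(0, 1))`, aligns it with the reference in this circular-shift toy case.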
Our demo shows that this approach produces reasonable results on the aforementioned CREMI dataset: It displays the predicted segmentation for an excerpt of its A+ dataset. Segments are mostly separated cleanly in the x- and y-directions. In the z-direction, segments are sometimes wrongly connected, leading to mergers. In general, getting the z-direction right is more challenging because of its lower resolution and imperfect alignment.
We only use an excerpt (slices 2-30) of the A+ dataset because we found that other sections contain much more image noise than the training dataset. Unsurprisingly, our model performs much worse on these sections, which leads to mergers across the entire dataset. We plan to address this issue with data augmentation: By artificially blurring and darkening the training data, we can teach our model to be more robust to varying image quality.
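A minimal sketch of such augmentation in NumPy, assuming raw intensities normalized to [0, 1]. The blur strength, darkening range, and 50% application probability are hypothetical values for illustration, not tuned settings; blurring is applied per z-slice because image quality varies from section to section.

```python
import numpy as np


def gaussian_kernel(sigma: float) -> np.ndarray:
    """Normalized 1D Gaussian kernel truncated at 3 sigma."""
    radius = max(1, int(3 * sigma))
    t = np.arange(-radius, radius + 1)
    kernel = np.exp(-0.5 * (t / sigma) ** 2)
    return kernel / kernel.sum()


def blur_slices(volume: np.ndarray, sigma: float) -> np.ndarray:
    """Gaussian-blur each z-slice of a (Z, Y, X) volume along y and x."""
    kernel = gaussian_kernel(sigma)
    radius = len(kernel) // 2

    def smooth(line):
        # replicate edge values so borders are not darkened by zero padding
        padded = np.pad(line, radius, mode="edge")
        return np.convolve(padded, kernel, mode="valid")

    out = np.apply_along_axis(smooth, 2, volume)  # along x
    return np.apply_along_axis(smooth, 1, out)    # along y


def augment(volume: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Randomly blur and/or darken a training volume with intensities in [0, 1]."""
    out = volume.astype(np.float64)
    if rng.random() < 0.5:
        out = blur_slices(out, sigma=rng.uniform(0.5, 2.0))
    if rng.random() < 0.5:
        out = out * rng.uniform(0.6, 1.0)  # darken by scaling intensities down
    return np.clip(out, 0.0, 1.0)
```

Only the raw image is augmented; the ground-truth labels stay unchanged, since blurring and darkening do not move any object boundaries.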
Outlook: Domain Adaptation
We implemented the system to segment neurons in brain image data, but the approach is likely to work for any image segmentation task, because nothing in the design is specific to the type of input data or output segmentation. I wouldn't be surprised to see it working for pedestrian segmentation in video data¹ from self-driving cars. But deciphering the brain should be exciting enough for now.
For medical datasets, an interesting direction for future work would be to apply domain adaptation techniques so that a model trained on one dataset can be used on a different (e.g., new) dataset. Perhaps we could pre-train a general-purpose segmentation network that generalizes to new datasets using only a small amount of training data. This would not be free of challenges: We would have to deal with different dataset resolutions, image properties, types of segmentation, …
We’ll keep you posted.
1 Video data can be interpreted as 3D image data by treating the time dimension as the z-dimension.