Solution for missing class label in multi-class segmentation
I am performing segmentation of 8 tissue classes. After a number of epochs, 7 of the tissue classes give more or less the expected results, but one of the eight classes is never predicted at all. For example, look at the image below:

This is the network output for a sample from the validation set. In almost all the samples, the eyeball label is missing and its voxels are predicted as 0 (background). However, in the training input images the eyeball intensities are in the range of 20-30, while the background is ~0.

I am using PyTorch and am thinking about using gradient accumulation, so that the network sees all the different tissue classes before each weight update.
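A minimal sketch of the gradient accumulation idea, under stated assumptions: the single `Conv3d` layer is only a stand-in for the real segmentation network, the batches are synthetic, and `num_classes` and `accum_steps` are hypothetical values. Gradients from several mini-batches are summed by repeated `backward()` calls, and the weights are updated only once per group of batches.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

num_classes = 9   # assumption: background (0) plus 8 tissue classes
accum_steps = 4   # assumption: mini-batches accumulated per weight update

# Stand-in model (hypothetical): one 3D convolution, not the real network.
model = nn.Conv3d(1, num_classes, kernel_size=3, padding=1)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Synthetic (image, label) batches standing in for the real data loader.
batches = [
    (torch.randn(1, 1, 8, 8, 8), torch.randint(0, num_classes, (1, 8, 8, 8)))
    for _ in range(8)
]

optimizer_steps = 0
optimizer.zero_grad()
for i, (image, label) in enumerate(batches, start=1):
    logits = model(image)                          # shape (N, C, D, H, W)
    # Divide so the accumulated gradient is an average over accum_steps batches.
    loss = criterion(logits, label) / accum_steps
    loss.backward()                                # gradients add up in .grad
    if i % accum_steps == 0:
        optimizer.step()                           # one update per group of batches
        optimizer.zero_grad()
        optimizer_steps += 1
```

With 8 batches and `accum_steps = 4`, the loop performs two weight updates. Note that this only changes how many samples contribute to each update; it does not change how strongly each class is weighted in the loss.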


@Ivan Here is the voxel count per class in the data, with the class index on the x-axis. I ignored the background voxels (class 0). Also, the eyeball class is 7. I am not sure what other statistics you wanted, so let me know.
Classes 7 and 8 have almost the same number of voxels, but training misses class 7 completely.
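For reference, a per-class voxel count like the one described above can be computed with `torch.bincount`. This is a sketch on synthetic label maps: the helper name `voxels_per_class`, the tensor shapes, and `num_classes = 9` (background plus 8 tissue classes) are assumptions for illustration. Comparing the counts of the ground truth with those of the prediction shows which classes are present in the labels but never predicted.

```python
import torch

def voxels_per_class(label_map, num_classes):
    """Count voxels per class index in an integer label map (hypothetical helper)."""
    return torch.bincount(label_map.flatten(), minlength=num_classes)

num_classes = 9  # assumption: background (0) plus 8 tissue classes

# Synthetic ground truth: 16 voxels of class 7, 16 of class 8, rest background.
target = torch.zeros(4, 4, 4, dtype=torch.long)
target[0] = 7
target[1] = 8

# Synthetic prediction that labels every class-7 voxel as background.
pred = target.clone()
pred[pred == 7] = 0

target_counts = voxels_per_class(target, num_classes)
pred_counts = voxels_per_class(pred, num_classes)

# Foreground classes that exist in the ground truth but are never predicted.
missing = [
    c for c in range(1, num_classes)
    if target_counts[c] > 0 and pred_counts[c] == 0
]
print("ground truth:", target_counts.tolist())
print("prediction:  ", pred_counts.tolist())
print("missing classes:", missing)
```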

