ITK segmentation preview

[1]:
import os
import glob
from skimage.measure import label
from IPython.display import HTML
from ipywidgets import interact, interactive, fixed, interact_manual, widgets

import numpy as np
import k3d
import SimpleITK as sitk
import vtk

from vtk.util import numpy_support

# Single shared k3d canvas; the voxel layer and the legend text are added to it
# by later cells.
plot = k3d.plot(
    camera_auto_fit=False,  # do not re-fit the camera when layers are added
)
plot.display()
[2]:
def load_geometry(filename, arrayName=0):
    """Read a VTK XML image (``.vti``) and return one point-data array as NumPy.

    Parameters
    ----------
    filename : str
        Path to the ``.vti`` file.
    arrayName : int or str, optional
        Name or index of the point-data array to extract (default: index 0,
        i.e. the first array).

    Returns
    -------
    tuple
        ``(imageVTK, volume)`` where ``imageVTK`` is the ``vtkImageData``
        produced by the reader and ``volume`` is the array reshaped to
        ``(z, y, x)`` for single-component data, or ``(z, y, x, n_components)``
        for multi-component data. ``volume`` is a view onto VTK-owned memory,
        kept alive by the returned ``imageVTK``.

    Raises
    ------
    FileNotFoundError
        If ``filename`` is not an existing file.
    ValueError
        If the requested array is not present in the image's point data.
    """
    # The VTK XML reader does not raise on a missing file (it only logs an
    # error and yields an empty image), so check up front.
    if not os.path.isfile(filename):
        raise FileNotFoundError(filename)

    reader = vtk.vtkXMLImageDataReader()
    reader.SetFileName(filename)
    reader.Update()
    imageVTK = reader.GetOutput()
    nx, ny, nz = imageVTK.GetDimensions()

    vtk_array = imageVTK.GetPointData().GetArray(arrayName)
    if vtk_array is None:
        raise ValueError(
            f"point-data array {arrayName!r} not found in {filename!r}")

    flat = numpy_support.vtk_to_numpy(vtk_array)
    # VTK point data is stored with x varying fastest, then y, then z, so a
    # C-order reshape yields (z, y, x). A trailing component axis (present
    # for multi-component arrays) is preserved instead of being folded into z.
    return imageVTK, flat.reshape((nz, ny, nx) + flat.shape[1:])

# Load the scan volume. The vtkImageData is kept for its spacing and bounds;
# the voxel array is replaced by a float32 copy, which later cells overwrite
# in place.
imageVTK, image = load_geometry('./segmentation/data/scan1.vti', 'data')
image = image.astype(dtype=np.float32, copy=True)
[3]:
# Ground-truth mask: bitwise union of the 'lca' and 'rca' label arrays
# (presumably left/right coronary arteries - confirm against the dataset).
gt_path = './segmentation/GT/scan1.vti'
GT1, GT2 = (load_geometry(gt_path, label)[1] for label in ('lca', 'rca'))
GT = GT1 | GT2
[4]:
# Seed voxel as an ITK (x, y, z) index; the NumPy array is indexed (z, y, x),
# hence the reversed order in the slice below.
seed = (230, 291, 227)

# Zero a 3-voxel-thick slab (x in [seed_x - 3, seed_x)) over a 100x100 (z, y)
# window around the seed. This presumably separates the region grown from
# x - 8 (first pass) from the one grown from x + 2 (second pass).
image[seed[2]-50:seed[2]+50, seed[1]-50:seed[1]+50, seed[0] - 3:seed[0]] = 0

itkScan = sitk.GetImageFromArray(image)
itkScan.SetSpacing(imageVTK.GetSpacing())
itkScan.SetOrigin(imageVTK.GetOrigin())

# First pass: region growing from the -x side of the slab.
seg_img = sitk.ConfidenceConnected(
    itkScan, [(seed[0]-8, seed[1], seed[2])],
    numberOfIterations=2, multiplier=3.2)

# Dilate the first-pass region by 3 voxels, then overwrite it with -1024 in
# the working array, presumably so the second pass cannot leak into it.
dilate = sitk.BinaryDilateImageFilter()
dilate.SetKernelRadius((3, 3, 3))
dilate.SetForegroundValue(1)
seg_img = dilate.Execute(seg_img)
seg2 = sitk.GetArrayFromImage(seg_img)
image[seg2 == 1] = -1024

# Rebuild the ITK image from the masked array. GetImageFromArray resets the
# spacing, origin and direction to defaults, so copy them from the first image.
masked_scan = sitk.GetImageFromArray(image)
masked_scan.CopyInformation(itkScan)
itkScan = masked_scan

# Second pass: region growing from the +x side of the slab; this is the final
# segmentation used by the visualization.
seg_img = sitk.ConfidenceConnected(
    itkScan, [(seed[0]+2, seed[1], seed[2])],
    numberOfIterations=3, multiplier=3.290)

seg = sitk.GetArrayFromImage(seg_img)
[5]:
# Confusion coding for k3d.voxels: 0 = empty, and value v is drawn with
# color_map[v - 1]:
#   1 = segmentation only (false positive, red)
#   2 = ground truth only (false negative, green)
#   3 = both              (true positive, yellow)
color_map = (0xff0000, 0x00ff00, 0xffff00)

# Binarize both masks so the codes stay within 0..3 even if a mask uses a
# non-unit foreground value (e.g. 255) or GT * 2 would overflow a narrow dtype.
pred_mask = (seg > 0).astype(np.uint8)
gt_mask = (GT > 0).astype(np.uint8)
voxels_classic = pred_mask + gt_mask * 2

plot += k3d.voxels(voxels_classic.astype(np.uint8),
                   color_map, outlines=False,
                   opacity=0.5,
                   bounds=imageVTK.GetBounds())
[6]:
# Legend captions, stacked down from the top-left corner of the canvas.
# The true-positive caption uses 0xaaaa00, a darker yellow than the 0xffff00
# voxel color (presumably for legibility on a light background).
legend_rows = (
    ('True Positive', 0.0, 0xaaaa00),
    ('False Positive', 0.12, 0xff0000),
    ('False Negative', 0.24, 0x00ff00),
)
for caption, y_offset, rgb in legend_rows:
    plot += k3d.text2d(
        caption,
        position=[0.05, y_offset],
        reference_point='lt',
        size=2.0,
        color=rgb,
    )
[ ]: