Skip to content
Snippets Groups Projects
Commit ed2f064d authored by mattminder's avatar mattminder
Browse files

improved performance of cell merging

parent 419c4433
No related branches found
No related tags found
No related merge requests found
"""
Source of the code: https://github.com/mattminder/yeastSegHelpers
"""
from scipy import ndimage as ndi
from skimage.feature import peak_local_max
from skimage.morphology import watershed
from skimage.morphology import watershed, dilation
# from PIL import Image
import numpy as np
from skimage import data, util, filters, color
import cv2
# import cv2
# get rid of this
from skimage import io
def segment(th, pred, min_distance=10, topology=None): #SJR: added pred to evaluate new borders
def segment(th, pred, min_distance=10, topology=None):
"""
Performs watershed segmentation on thresholded image. Seeds have to
have minimal distance of min_distance. topology defines the watershed
topology to be used, default is the negative distance transform. Can
either be an array with the same size af th, or a function that will
be applied to the distance transform.
After watershed, the borders found by watershed will be evaluated in terms
of their predicted value. If the borders are highly predicted to be cells,
the two cells are merged.
"""
dtr = ndi.morphology.distance_transform_edt(th)
if topology is None:
......@@ -31,94 +27,125 @@ def segment(th, pred, min_distance=10, topology=None): #SJR: added pred to evalu
m = peak_local_max(-topology, min_distance, indices=False)
m_lab = ndi.label(m)[0]
# print(m_lab.shape)
# print(type(m_lab[0,0]))
# print(type(th[0,0]))
# print(type(topology[0,0]))
# io.imsave("/home/sjrahi/Desktop/peaks.tif", np.array(m_lab, np.uint32))
# io.imsave("/home/sjrahi/Desktop/image.tif", np.array(th, np.float32))
# io.imsave("/home/sjrahi/Desktop/top.tif", np.array(topology, np.float32))
# io.imsave("/home/sjrahi/Desktop/peaks.tif", m_lab)
# io.imsave("/home/sjrahi/Desktop/image.tif", th)
# io.imsave("/home/sjrahi/Desktop/top.tif", topology)
wsh = watershed(topology, m_lab, mask=th)
return cell_merge(wsh, pred)
# print(m_lab.shape)
# print(wsh.shape)
# print(th.shape)
# print(topology.shape)
print(type(wsh[0,0]))
# io.imsave("/home/sjrahi/Desktop/wsh.tif", np.array(wsh, np.uint32))
# io.imsave("/home/sjrahi/Desktop/wsh.tif", wsh)
wshshape=wsh.shape # size of the watershed images, could probably get the sizes from the original input images but too lazy to figure out how
oriobjs=np.zeros((wsh.max()+1,wshshape[0],wshshape[1])) # the masks for the original cells are saved each separately here
dilobjs=np.zeros((wsh.max()+1,wshshape[0],wshshape[1])) # the masks for dilated cells are saved each separately here
objcoords=np.zeros((wsh.max()+1,4)) # coordinates of the bounding boxes for each dilated object saved here
wshclean=np.zeros((wshshape[0],wshshape[1]))
def cell_merge(wsh, pred):
"""
Procedure that merges cells if the border between them is predicted to be
cell pixels.
"""
wshshape=wsh.shape
# masks for the original cells
objs = np.zeros((wsh.max()+1,wshshape[0],wshshape[1]))
# masks for dilated cells
dil_objs = np.zeros((wsh.max()+1,wshshape[0],wshshape[1]))
# bounding box coordinates
obj_coords = np.zeros((wsh.max()+1,4))
kernel = np.ones((3,3), np.uint8) # need kernel to dilate objects
for obj1 in range(0,wsh.max()): # objects numbered starting with 0!! Careful!!
oriobjs[obj1,:,:] = np.uint8(np.multiply(wsh==(obj1+1),1)) # create a mask for each object (number: obj1+1)
dilobjs[obj1,:,:] = cv2.dilate(oriobjs[obj1,:,:], kernel, iterations=1) # create a mask for each object (number: obj1+1) after dilation
objallpixels=np.where(dilobjs[obj1,:,:] != 0)
objcoords[obj1,0]=np.min(objallpixels[0])
objcoords[obj1,1]=np.max(objallpixels[0])
objcoords[obj1,2]=np.min(objallpixels[1])
objcoords[obj1,3]=np.max(objallpixels[1])
# cleaned watershed, output of function
wshclean = np.zeros((wshshape[0],wshshape[1]))
# kernel to dilate objects
kernel = np.ones((3,3), dtype=bool)
for obj1 in range(wsh.max()):
# create masks and dilated masks for obj
objs[obj1,:,:] = wsh==(obj1+1)
dil_objs[obj1,:,:] = dilation(objs[obj1,:,:], kernel)
# bounding box
obj_coords[obj1,:] = get_bounding_box(dil_objs[obj1,:,:])
# objallpixels = np.where(dilobjs[obj1,:,:] != 0)
# objcoords[obj1,0]=np.min(objallpixels[0])
# objcoords[obj1,1]=np.max(objallpixels[0])
# objcoords[obj1,2]=np.min(objallpixels[1])
# objcoords[obj1,3]=np.max(objallpixels[1])
objcounter = 0 # will build up a new watershed mask, have to run a counter because some objects lost
for obj1 in range(0,wsh.max()): #careful, object numbers -1 !
for obj1 in range(wsh.max()):
print("Processing cell ",obj1+1," of ",wsh.max()," for oversegmentation.")
maskobj1 = dilobjs[obj1,:,:]
dil1 = dil_objs[obj1,:,:]
if np.sum(maskobj1) > 0: #maskobj1 can be empty because in the loop, maskobj2 can be deleted if it is joined with a (previous) maskobj1
if np.sum(dil1) > 0: #dil1 can be empty because in the loop, maskobj2 can be deleted if it is joined with a (previous) maskobj1
objcounter = objcounter + 1
maskoriobj1 = oriobjs[obj1,:,:]
orig1 = objs[obj1,:,:]
for obj2 in range(obj1+1,wsh.max()):
maskobj2 = dilobjs[obj2,:,:]
dil2 = dil_objs[obj2,:,:]
if (np.sum(maskobj2) > 0 and #maskobj1 and 2 can be empty because joined with maskobj2's and then maskobj2's deleted (set to zero)
(((objcoords[obj1,0] - 2 < objcoords[obj2,0] and objcoords[obj1,1] + 2 > objcoords[obj2,0]) or # do the bounding boxes overlap? plus/minus 2 pixels to allow for bad bounding box measurement
(objcoords[obj2,0] - 2 < objcoords[obj1,0] and objcoords[obj2,1] + 2 > objcoords[obj1,0])) and
((objcoords[obj1,2] - 2 < objcoords[obj2,2] and objcoords[obj1,3] + 2 > objcoords[obj2,2]) or
(objcoords[obj2,2] - 2 < objcoords[obj1,2] and objcoords[obj2,3] + 2 > objcoords[obj1,2])))):
border = maskobj1 * maskobj2 #intersection of two masks constitutes a border
# borderarea = np.sum(border)
# borderpred = border * pred
# borderheight = np.sum(borderpred)
borderprednonzero = pred[np.nonzero(border)] # all the prediction values inside the border area
sortborderprednonzero = sorted(borderprednonzero) # sort the values
borderprednonzeroarea = len(borderprednonzero) # how many values are there?
quartborderarea = round(borderprednonzeroarea/4) # take one fourth of the values. there is some subtlety about how round() rounds but doesn't matter
topborderpred = sortborderprednonzero[quartborderarea:] # take top 3/4 of the predictions
topborderheight = np.sum(topborderpred) # sum over top 3/4 of the predictions
topborderarea = len(topborderpred) # area of 3/4 of predictions. In principle equal to 3/4 of borderprednonzeroarea but because of strange rounding, will just measure again
if topborderarea > 8: # SJR: Not only must borderarea be greater than 0 but also have a little bit of border to go on.
#print(obj1+1, obj2+1, topborderheight/topborderarea)
if topborderheight/topborderarea > 0.99 : # SJR: We are really deep inside a cell, where the prediction is =1. Won't use: borderheight/borderarea > 0.95. Redundant.
#print("--")
#print(objcounter)
#wsh=np.where(wsh==obj2+1, obj1+1, wsh)
maskoriobj1 = np.uint8(np.multiply((maskoriobj1 > 0) | (oriobjs[obj2,:,:] > 0),1)) #have to do boolean then integer just to do an 'or'
dilobjs[obj1,:,:] = np.uint8(np.multiply((maskobj1 > 0) | (maskobj2 > 0),1)) #have to do boolean then integer just to do an 'or'
dilobjs[obj2,:,:] = np.zeros((wshshape[0],wshshape[1]))
objcoords[obj1,0] = min(objcoords[obj1,0],objcoords[obj2,0])
objcoords[obj1,1] = max(objcoords[obj1,1],objcoords[obj2,1])
objcoords[obj1,2] = min(objcoords[obj1,2],objcoords[obj2,2])
objcoords[obj1,3] = max(objcoords[obj1,3],objcoords[obj2,3])
print("Merged cell ",obj1+1," and ",obj2+1,".")
if (do_box_overlap(obj_coords[obj1,:], obj_coords[obj2,:])
and np.sum(dil2) > 0):
border = dil1 * dil2
border_pred = pred[border]
# Border is too small to be considered
if len(border_pred) < 32:
continue
# Sum of top 25% of predicted border values
q75 = np.quantile(border_pred, .75)
top_border_pred = border_pred[border_pred > q75]
top_border_height = top_border_pred.sum()
top_border_area = len(top_border_pred)
# borderprednonzero = pred[np.nonzero(border)] # all the prediction values inside the border area
# sortborderprednonzero = sorted(borderprednonzero) # sort the values
# borderprednonzeroarea = len(borderprednonzero) # how many values are there?
# quartborderarea = round(borderprednonzeroarea/4) # take one fourth of the values. there is some subtlety about how round() rounds but doesn't matter
# topborderpred = sortborderprednonzero[quartborderarea:] # take top 3/4 of the predictions
# topborderheight = np.sum(topborderpred) # sum over top 3/4 of the predictions
# topborderarea = len(topborderpred) # area of 3/4 of predictions. In principle equal to 3/4 of borderprednonzeroarea but because of strange rounding, will just measure again
# merge cells
if top_border_height / top_border_area > .99:
orig1 = np.logical_or(orig1, objs[obj2,:,:])
dil_objs[obj1,:,:] = np.logical_or(dil1, dil2)
dil_objs[obj2,:,:] = np.zeros((wshshape[0], wshshape[1]))
obj_coords[obj1,:] = get_bounding_box(dil_objs[obj1,:,:])
print("Merged cell ",obj1+1," and ",obj2+1,".")
# if topborderarea > 8: # SJR: Not only must borderarea be greater than 0 but also have a little bit of border to go on.
# #print(obj1+1, obj2+1, topborderheight/topborderarea)
# if topborderheight/topborderarea > 0.99 : # SJR: We are really deep inside a cell, where the prediction is =1. Won't use: borderheight/borderarea > 0.95. Redundant.
# #print("--")
# #print(objcounter)
# #wsh=np.where(wsh==obj2+1, obj1+1, wsh)
# maskoriobj1 = np.uint8(np.multiply((maskoriobj1 > 0) | (oriobjs[obj2,:,:] > 0),1)) #have to do boolean then integer just to do an 'or'
# dilobjs[obj1,:,:] = np.uint8(np.multiply((maskobj1 > 0) | (maskobj2 > 0),1)) #have to do boolean then integer just to do an 'or'
# dilobjs[obj2,:,:] = np.zeros((wshshape[0],wshshape[1]))
# objcoords[obj1,0] = min(objcoords[obj1,0],objcoords[obj2,0])
# objcoords[obj1,1] = max(objcoords[obj1,1],objcoords[obj2,1])
# objcoords[obj1,2] = min(objcoords[obj1,2],objcoords[obj2,2])
# objcoords[obj1,3] = max(objcoords[obj1,3],objcoords[obj2,3])
# print("Merged cell ",obj1+1," and ",obj2+1,".")
wshclean = wshclean + orig1*objcounter
return wshclean
wshclean = wshclean + maskoriobj1*objcounter
#else:
# display(obj1+1,' no longer there.')
return wshclean
# return wsh
def do_box_overlap(coord1, coord2):
"""Checks if boxes, determined by their coordinates, overlap. Safety
margin of 2 pixels"""
return (
(coord1[0] - 2 < coord2[0] and coord1[1] + 2 > coord2[0]
or coord2[0] - 2 < coord1[0] and coord2[1] + 2 > coord1[0])
and (coord1[2] - 2 < coord2[2] and coord1[3] + 2 > coord2[2]
or coord2[2] - 2 < coord1[2] and coord2[3] + 2 > coord1[2]))
def get_bounding_box(im):
"""Returns bounding box of object in boolean image"""
coords = np.where(im)
min0, min1 = coords.min(axis=0)
max0, max1 = coords.max(axis=0)
return np.array([min0, max0, min1, max1])
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment