From ed2f064d38d9446f669698ee6b786a9c8d962773 Mon Sep 17 00:00:00 2001
From: mattminder <myfiles@Mattus-MacBook-Pro.local>
Date: Wed, 13 May 2020 09:56:50 +0200
Subject: [PATCH] improved performance of cell merging

---
 unet/segment.py | 203 +++++++++++++++++++++++++++---------------------
 1 file changed, 115 insertions(+), 88 deletions(-)

diff --git a/unet/segment.py b/unet/segment.py
index 000fbf0..2ca2670 100644
--- a/unet/segment.py
+++ b/unet/segment.py
@@ -1,26 +1,22 @@
-"""
-Source of the code: https://github.com/mattminder/yeastSegHelpers
-"""
 from scipy import ndimage as ndi
 from skimage.feature import peak_local_max
-from skimage.morphology import watershed
+from skimage.morphology import watershed, dilation
 
-# from PIL import Image
 import numpy as np
-from skimage import data, util, filters, color
-import cv2
+# import cv2
 
-# get rid of this
-from skimage import io
 
-
-def segment(th, pred, min_distance=10, topology=None): #SJR: added pred to evaluate new borders
+def segment(th, pred, min_distance=10, topology=None): 
     """
     Performs watershed segmentation on thresholded image. Seeds have to
     have minimal distance of min_distance. topology defines the watershed
     topology to be used, default is the negative distance transform. Can
     either be an array with the same size af th, or a function that will
     be applied to the distance transform.
+    
+    After watershed, the borders found by watershed will be evaluated in terms
+    of their predicted value. If the borders are highly predicted to be cells,
+    the two cells are merged. 
     """
     dtr = ndi.morphology.distance_transform_edt(th)
     if topology is None:
@@ -31,94 +27,125 @@ def segment(th, pred, min_distance=10, topology=None): #SJR: added pred to evalu
 
     m = peak_local_max(-topology, min_distance, indices=False)
     m_lab = ndi.label(m)[0]
-#    print(m_lab.shape)
-#    print(type(m_lab[0,0]))
-#    print(type(th[0,0]))
-#    print(type(topology[0,0]))
-#    io.imsave("/home/sjrahi/Desktop/peaks.tif",               np.array(m_lab, np.uint32))
-#    io.imsave("/home/sjrahi/Desktop/image.tif",               np.array(th, np.float32))
-#    io.imsave("/home/sjrahi/Desktop/top.tif",                 np.array(topology, np.float32))
-#    io.imsave("/home/sjrahi/Desktop/peaks.tif",               m_lab)
-#    io.imsave("/home/sjrahi/Desktop/image.tif",               th)
-#    io.imsave("/home/sjrahi/Desktop/top.tif",                 topology)
     wsh = watershed(topology, m_lab, mask=th)
+    return cell_merge(wsh, pred)
+    
 
-#    print(m_lab.shape)
-#    print(wsh.shape)
-#    print(th.shape)
-#    print(topology.shape)
-
-    print(type(wsh[0,0]))
-#    io.imsave("/home/sjrahi/Desktop/wsh.tif",               np.array(wsh, np.uint32))
-#    io.imsave("/home/sjrahi/Desktop/wsh.tif",                 wsh)
-
-
-    wshshape=wsh.shape	# size of the watershed images, could probably get the sizes from the original input images but too lazy to figure out how
-    oriobjs=np.zeros((wsh.max()+1,wshshape[0],wshshape[1]))	# the masks for the original cells are saved each separately here
-    dilobjs=np.zeros((wsh.max()+1,wshshape[0],wshshape[1]))	# the masks for dilated cells are saved each separately here
-    objcoords=np.zeros((wsh.max()+1,4))			# coordinates of the bounding boxes for each dilated object saved here
-    wshclean=np.zeros((wshshape[0],wshshape[1]))
+def cell_merge(wsh, pred):
+    """
+    Procedure that merges cells if the border between them is predicted to be
+    cell pixels.
+    """
+    wshshape=wsh.shape
+    
+    # masks for the original cells
+    objs = np.zeros((wsh.max()+1,wshshape[0],wshshape[1]))	
+    
+    # masks for dilated cells
+    dil_objs = np.zeros((wsh.max()+1,wshshape[0],wshshape[1]))
+    
+    # bounding box coordinates	
+    obj_coords = np.zeros((wsh.max()+1,4))
     
-    kernel = np.ones((3,3), np.uint8)	# need kernel to dilate objects
-    for obj1 in range(0,wsh.max()):	# objects numbered starting with 0!! Careful!!
-        oriobjs[obj1,:,:] = np.uint8(np.multiply(wsh==(obj1+1),1))		# create a mask for each object (number: obj1+1)
-        dilobjs[obj1,:,:] = cv2.dilate(oriobjs[obj1,:,:], kernel, iterations=1)	# create a mask for each object (number: obj1+1) after dilation
-        objallpixels=np.where(dilobjs[obj1,:,:] != 0)
-        objcoords[obj1,0]=np.min(objallpixels[0])
-        objcoords[obj1,1]=np.max(objallpixels[0])
-        objcoords[obj1,2]=np.min(objallpixels[1])
-        objcoords[obj1,3]=np.max(objallpixels[1])
+    # cleaned watershed, output of function	
+    wshclean = np.zeros((wshshape[0],wshshape[1]))
+    
+    # kernel to dilate objects
+    kernel = np.ones((3,3), dtype=bool)	
+    
+    for obj1 in range(wsh.max()):
+        # create masks and dilated masks for obj
+        objs[obj1,:,:] = wsh==(obj1+1)	
+        dil_objs[obj1,:,:] = dilation(objs[obj1,:,:], kernel)	
+        
+        # bounding box
+        obj_coords[obj1,:] = get_bounding_box(dil_objs[obj1,:,:])
+#        objallpixels = np.where(dilobjs[obj1,:,:] != 0)
+#        objcoords[obj1,0]=np.min(objallpixels[0])
+#        objcoords[obj1,1]=np.max(objallpixels[0])
+#        objcoords[obj1,2]=np.min(objallpixels[1])
+#        objcoords[obj1,3]=np.max(objallpixels[1])
     
     objcounter = 0	# will build up a new watershed mask, have to run a counter because some objects lost
-    for obj1 in range(0,wsh.max()):	#careful, object numbers -1 !
+    
+    for obj1 in range(wsh.max()):	
         print("Processing cell ",obj1+1," of ",wsh.max()," for oversegmentation.")
-        maskobj1 = dilobjs[obj1,:,:]
+        dil1 = dil_objs[obj1,:,:]
 
-        if np.sum(maskobj1) > 0:		#maskobj1 can be empty because in the loop, maskobj2 can be deleted if it is joined with a (previous) maskobj1
+        if np.sum(dil1) > 0:		#dil1 can be empty because in the loop, maskobj2 can be deleted if it is joined with a (previous) maskobj1
             objcounter = objcounter + 1
-            maskoriobj1 = oriobjs[obj1,:,:]
+            orig1 = objs[obj1,:,:]
 
             for obj2 in range(obj1+1,wsh.max()):
-                maskobj2 = dilobjs[obj2,:,:]
+                dil2 = dil_objs[obj2,:,:]
             
-                if (np.sum(maskobj2) > 0 and	#maskobj1 and 2 can be empty because joined with maskobj2's and then maskobj2's deleted (set to zero)
-                (((objcoords[obj1,0] - 2 < objcoords[obj2,0] and objcoords[obj1,1] + 2 > objcoords[obj2,0]) or		# do the bounding boxes overlap? plus/minus 2 pixels to allow for bad bounding box measurement
-                  (objcoords[obj2,0] - 2 < objcoords[obj1,0] and objcoords[obj2,1] + 2 > objcoords[obj1,0])) and
-                 ((objcoords[obj1,2] - 2 < objcoords[obj2,2] and objcoords[obj1,3] + 2 > objcoords[obj2,2]) or
-                  (objcoords[obj2,2] - 2 < objcoords[obj1,2] and objcoords[obj2,3] + 2 > objcoords[obj1,2])))):
-                    border = maskobj1 * maskobj2	#intersection of two masks constitutes a border
-                    # borderarea = np.sum(border)
-                    # borderpred = border * pred
-                    # borderheight = np.sum(borderpred)
-
-                    borderprednonzero = pred[np.nonzero(border)]		# all the prediction values inside the border area
-                    sortborderprednonzero = sorted(borderprednonzero)		# sort the values
-                    borderprednonzeroarea = len(borderprednonzero)		# how many values are there?
-                    quartborderarea = round(borderprednonzeroarea/4)		# take one fourth of the values. there is some subtlety about how round() rounds but doesn't matter
-                    topborderpred = sortborderprednonzero[quartborderarea:]	# take top 3/4 of the predictions
-                    topborderheight = np.sum(topborderpred)			# sum over top 3/4 of the predictions
-                    topborderarea = len(topborderpred)				# area of 3/4 of predictions. In principle equal to 3/4 of borderprednonzeroarea but because of strange rounding, will just measure again
-
-                    if topborderarea > 8:	# SJR: Not only must borderarea be greater than 0 but also have a little bit of border to go on.
-                        #print(obj1+1, obj2+1, topborderheight/topborderarea)
-                        if topborderheight/topborderarea > 0.99 :	# SJR: We are really deep inside a cell, where the prediction is =1. Won't use: borderheight/borderarea > 0.95. Redundant.
-                            #print("--")
-                            #print(objcounter)
-                            #wsh=np.where(wsh==obj2+1, obj1+1, wsh)
-                            maskoriobj1       = np.uint8(np.multiply((maskoriobj1 > 0) | (oriobjs[obj2,:,:] > 0),1))		#have to do boolean then integer just to do an 'or'
-                            dilobjs[obj1,:,:] = np.uint8(np.multiply((maskobj1    > 0) | (maskobj2          > 0),1))		#have to do boolean then integer just to do an 'or'
-                            dilobjs[obj2,:,:] = np.zeros((wshshape[0],wshshape[1]))
-                            objcoords[obj1,0] = min(objcoords[obj1,0],objcoords[obj2,0])
-                            objcoords[obj1,1] = max(objcoords[obj1,1],objcoords[obj2,1])
-                            objcoords[obj1,2] = min(objcoords[obj1,2],objcoords[obj2,2])
-                            objcoords[obj1,3] = max(objcoords[obj1,3],objcoords[obj2,3])
-                            print("Merged cell ",obj1+1," and ",obj2+1,".")
-
+                if (do_box_overlap(obj_coords[obj1,:], obj_coords[obj2,:])
+                    and np.sum(dil2) > 0):
+                    border = dil1 * dil2	
+                    
+                    border_pred = pred[border]
+                    
+                    # Border is too small to be considered
+                    if len(border_pred) < 32:
+                        continue
+                    
+                    # Sum of top 25% of predicted border values
+                    q75 = np.quantile(border_pred, .75)
+                    top_border_pred = border_pred[border_pred > q75]
+                    top_border_height = top_border_pred.sum()
+                    top_border_area = len(top_border_pred)
+                    
+#                    borderprednonzero = pred[np.nonzero(border)]		# all the prediction values inside the border area
+#                    sortborderprednonzero = sorted(borderprednonzero)		# sort the values
+#                    borderprednonzeroarea = len(borderprednonzero)		# how many values are there?
+#                    quartborderarea = round(borderprednonzeroarea/4)		# take one fourth of the values. there is some subtlety about how round() rounds but doesn't matter
+#                    topborderpred = sortborderprednonzero[quartborderarea:]	# take top 3/4 of the predictions
+#                    topborderheight = np.sum(topborderpred)			# sum over top 3/4 of the predictions
+#                    topborderarea = len(topborderpred)				# area of 3/4 of predictions. In principle equal to 3/4 of borderprednonzeroarea but because of strange rounding, will just measure again
+
+                    # merge cells
+                    if top_border_height / top_border_area > .99:
+                        orig1 = np.logical_or(orig1, objs[obj2,:,:])
+                        dil_objs[obj1,:,:] = np.logical_or(dil1, dil2)
+                        dil_objs[obj2,:,:] = np.zeros((wshshape[0], wshshape[1]))
+                        obj_coords[obj1,:] = get_bounding_box(dil_objs[obj1,:,:])
+                        print("Merged cell ",obj1+1," and ",obj2+1,".")
+                        
+                        
+#                    if topborderarea > 8:	# SJR: Not only must borderarea be greater than 0 but also have a little bit of border to go on.
+#                        #print(obj1+1, obj2+1, topborderheight/topborderarea)
+#                        if topborderheight/topborderarea > 0.99 :	# SJR: We are really deep inside a cell, where the prediction is =1. Won't use: borderheight/borderarea > 0.95. Redundant.
+#                            #print("--")
+#                            #print(objcounter)
+#                            #wsh=np.where(wsh==obj2+1, obj1+1, wsh)
+#                            maskoriobj1       = np.uint8(np.multiply((maskoriobj1 > 0) | (oriobjs[obj2,:,:] > 0),1))		#have to do boolean then integer just to do an 'or'
+#                            dilobjs[obj1,:,:] = np.uint8(np.multiply((maskobj1    > 0) | (maskobj2          > 0),1))		#have to do boolean then integer just to do an 'or'
+#                            dilobjs[obj2,:,:] = np.zeros((wshshape[0],wshshape[1]))
+#                            objcoords[obj1,0] = min(objcoords[obj1,0],objcoords[obj2,0])
+#                            objcoords[obj1,1] = max(objcoords[obj1,1],objcoords[obj2,1])
+#                            objcoords[obj1,2] = min(objcoords[obj1,2],objcoords[obj2,2])
+#                            objcoords[obj1,3] = max(objcoords[obj1,3],objcoords[obj2,3])
+#                            print("Merged cell ",obj1+1," and ",obj2+1,".")
+
+            wshclean = wshclean + orig1*objcounter
+            
+    return wshclean
 
-            wshclean = wshclean + maskoriobj1*objcounter
-        #else:
-         #   display(obj1+1,' no longer there.')
 
-    return wshclean
-#    return wsh
+def do_box_overlap(coord1, coord2):
+    """Checks if boxes, determined by their coordinates, overlap. Safety
+    margin of 2 pixels"""
+    return (
+    (coord1[0] - 2 < coord2[0] and coord1[1] + 2 > coord2[0]
+        or coord2[0] - 2 < coord1[0] and coord2[1] + 2 > coord1[0]) 
+    and (coord1[2] - 2 < coord2[2] and coord1[3] + 2 > coord2[2]
+        or coord2[2] - 2 < coord1[2] and coord2[3] + 2 > coord1[2]))
 
+    
+def get_bounding_box(im):
+    """Returns bounding box of object in boolean image"""
+    coords = np.where(im)
+    
+    min0, min1 = coords.min(axis=0)
+    max0, max1 = coords.max(axis=0)
+    return np.array([min0, max0, min1, max1])
-- 
GitLab