diff --git a/unet/segment.py b/unet/segment.py index 000fbf0c2f8b152a6ca1628da308fccc3c8696ea..2ca26702520d45e49c8565dcd6677af8bc2f7943 100644 --- a/unet/segment.py +++ b/unet/segment.py @@ -1,26 +1,22 @@ -""" -Source of the code: https://github.com/mattminder/yeastSegHelpers -""" from scipy import ndimage as ndi from skimage.feature import peak_local_max -from skimage.morphology import watershed +from skimage.morphology import watershed, dilation -# from PIL import Image import numpy as np -from skimage import data, util, filters, color -import cv2 +# import cv2 -# get rid of this -from skimage import io - -def segment(th, pred, min_distance=10, topology=None): #SJR: added pred to evaluate new borders +def segment(th, pred, min_distance=10, topology=None): """ Performs watershed segmentation on thresholded image. Seeds have to have minimal distance of min_distance. topology defines the watershed topology to be used, default is the negative distance transform. Can either be an array with the same size af th, or a function that will be applied to the distance transform. + + After watershed, the borders found by watershed will be evaluated in terms + of their predicted value. If the borders are highly predicted to be cells, + the two cells are merged. 
""" dtr = ndi.morphology.distance_transform_edt(th) if topology is None: @@ -31,94 +27,125 @@ def segment(th, pred, min_distance=10, topology=None): #SJR: added pred to evalu m = peak_local_max(-topology, min_distance, indices=False) m_lab = ndi.label(m)[0] -# print(m_lab.shape) -# print(type(m_lab[0,0])) -# print(type(th[0,0])) -# print(type(topology[0,0])) -# io.imsave("/home/sjrahi/Desktop/peaks.tif", np.array(m_lab, np.uint32)) -# io.imsave("/home/sjrahi/Desktop/image.tif", np.array(th, np.float32)) -# io.imsave("/home/sjrahi/Desktop/top.tif", np.array(topology, np.float32)) -# io.imsave("/home/sjrahi/Desktop/peaks.tif", m_lab) -# io.imsave("/home/sjrahi/Desktop/image.tif", th) -# io.imsave("/home/sjrahi/Desktop/top.tif", topology) wsh = watershed(topology, m_lab, mask=th) + return cell_merge(wsh, pred) + -# print(m_lab.shape) -# print(wsh.shape) -# print(th.shape) -# print(topology.shape) - - print(type(wsh[0,0])) -# io.imsave("/home/sjrahi/Desktop/wsh.tif", np.array(wsh, np.uint32)) -# io.imsave("/home/sjrahi/Desktop/wsh.tif", wsh) - - - wshshape=wsh.shape # size of the watershed images, could probably get the sizes from the original input images but too lazy to figure out how - oriobjs=np.zeros((wsh.max()+1,wshshape[0],wshshape[1])) # the masks for the original cells are saved each separately here - dilobjs=np.zeros((wsh.max()+1,wshshape[0],wshshape[1])) # the masks for dilated cells are saved each separately here - objcoords=np.zeros((wsh.max()+1,4)) # coordinates of the bounding boxes for each dilated object saved here - wshclean=np.zeros((wshshape[0],wshshape[1])) +def cell_merge(wsh, pred): + """ + Procedure that merges cells if the border between them is predicted to be + cell pixels. 
+ """ + wshshape=wsh.shape + + # masks for the original cells + objs = np.zeros((wsh.max()+1,wshshape[0],wshshape[1])) + + # masks for dilated cells + dil_objs = np.zeros((wsh.max()+1,wshshape[0],wshshape[1])) + + # bounding box coordinates + obj_coords = np.zeros((wsh.max()+1,4)) - kernel = np.ones((3,3), np.uint8) # need kernel to dilate objects - for obj1 in range(0,wsh.max()): # objects numbered starting with 0!! Careful!! - oriobjs[obj1,:,:] = np.uint8(np.multiply(wsh==(obj1+1),1)) # create a mask for each object (number: obj1+1) - dilobjs[obj1,:,:] = cv2.dilate(oriobjs[obj1,:,:], kernel, iterations=1) # create a mask for each object (number: obj1+1) after dilation - objallpixels=np.where(dilobjs[obj1,:,:] != 0) - objcoords[obj1,0]=np.min(objallpixels[0]) - objcoords[obj1,1]=np.max(objallpixels[0]) - objcoords[obj1,2]=np.min(objallpixels[1]) - objcoords[obj1,3]=np.max(objallpixels[1]) + # cleaned watershed, output of function + wshclean = np.zeros((wshshape[0],wshshape[1])) + + # kernel to dilate objects + kernel = np.ones((3,3), dtype=bool) + + for obj1 in range(wsh.max()): + # create masks and dilated masks for obj + objs[obj1,:,:] = wsh==(obj1+1) + dil_objs[obj1,:,:] = dilation(objs[obj1,:,:], kernel) + + # bounding box + obj_coords[obj1,:] = get_bounding_box(dil_objs[obj1,:,:]) +# objallpixels = np.where(dilobjs[obj1,:,:] != 0) +# objcoords[obj1,0]=np.min(objallpixels[0]) +# objcoords[obj1,1]=np.max(objallpixels[0]) +# objcoords[obj1,2]=np.min(objallpixels[1]) +# objcoords[obj1,3]=np.max(objallpixels[1]) objcounter = 0 # will build up a new watershed mask, have to run a counter because some objects lost - for obj1 in range(0,wsh.max()): #careful, object numbers -1 ! 
+ + for obj1 in range(wsh.max()): print("Processing cell ",obj1+1," of ",wsh.max()," for oversegmentation.") - maskobj1 = dilobjs[obj1,:,:] + dil1 = dil_objs[obj1,:,:] - if np.sum(maskobj1) > 0: #maskobj1 can be empty because in the loop, maskobj2 can be deleted if it is joined with a (previous) maskobj1 + if np.sum(dil1) > 0: #dil1 can be empty because in the loop, maskobj2 can be deleted if it is joined with a (previous) maskobj1 objcounter = objcounter + 1 - maskoriobj1 = oriobjs[obj1,:,:] + orig1 = objs[obj1,:,:] for obj2 in range(obj1+1,wsh.max()): - maskobj2 = dilobjs[obj2,:,:] + dil2 = dil_objs[obj2,:,:] - if (np.sum(maskobj2) > 0 and #maskobj1 and 2 can be empty because joined with maskobj2's and then maskobj2's deleted (set to zero) - (((objcoords[obj1,0] - 2 < objcoords[obj2,0] and objcoords[obj1,1] + 2 > objcoords[obj2,0]) or # do the bounding boxes overlap? plus/minus 2 pixels to allow for bad bounding box measurement - (objcoords[obj2,0] - 2 < objcoords[obj1,0] and objcoords[obj2,1] + 2 > objcoords[obj1,0])) and - ((objcoords[obj1,2] - 2 < objcoords[obj2,2] and objcoords[obj1,3] + 2 > objcoords[obj2,2]) or - (objcoords[obj2,2] - 2 < objcoords[obj1,2] and objcoords[obj2,3] + 2 > objcoords[obj1,2])))): - border = maskobj1 * maskobj2 #intersection of two masks constitutes a border - # borderarea = np.sum(border) - # borderpred = border * pred - # borderheight = np.sum(borderpred) - - borderprednonzero = pred[np.nonzero(border)] # all the prediction values inside the border area - sortborderprednonzero = sorted(borderprednonzero) # sort the values - borderprednonzeroarea = len(borderprednonzero) # how many values are there? - quartborderarea = round(borderprednonzeroarea/4) # take one fourth of the values. 
there is some subtlety about how round() rounds but doesn't matter
-                        topborderpred = sortborderprednonzero[quartborderarea:] # take top 3/4 of the predictions
-                        topborderheight = np.sum(topborderpred) # sum over top 3/4 of the predictions
-                        topborderarea = len(topborderpred) # area of 3/4 of predictions. In principle equal to 3/4 of borderprednonzeroarea but because of strange rounding, will just measure again
-
-                        if topborderarea > 8: # SJR: Not only must borderarea be greater than 0 but also have a little bit of border to go on.
-                            #print(obj1+1, obj2+1, topborderheight/topborderarea)
-                            if topborderheight/topborderarea > 0.99 : # SJR: We are really deep inside a cell, where the prediction is =1. Won't use: borderheight/borderarea > 0.95. Redundant.
-                                #print("--")
-                                #print(objcounter)
-                                #wsh=np.where(wsh==obj2+1, obj1+1, wsh)
-                                maskoriobj1 = np.uint8(np.multiply((maskoriobj1 > 0) | (oriobjs[obj2,:,:] > 0),1)) #have to do boolean then integer just to do an 'or'
-                                dilobjs[obj1,:,:] = np.uint8(np.multiply((maskobj1 > 0) | (maskobj2 > 0),1)) #have to do boolean then integer just to do an 'or'
-                                dilobjs[obj2,:,:] = np.zeros((wshshape[0],wshshape[1]))
-                                objcoords[obj1,0] = min(objcoords[obj1,0],objcoords[obj2,0])
-                                objcoords[obj1,1] = max(objcoords[obj1,1],objcoords[obj2,1])
-                                objcoords[obj1,2] = min(objcoords[obj1,2],objcoords[obj2,2])
-                                objcoords[obj1,3] = max(objcoords[obj1,3],objcoords[obj2,3])
-                                print("Merged cell ",obj1+1," and ",obj2+1,".")
-
+                if (do_box_overlap(obj_coords[obj1,:], obj_coords[obj2,:])
+                        and np.sum(dil2) > 0):
+                    border = dil1 * dil2
+
+                    border_pred = pred[border > 0]
+
+                    # Border is too small to be considered
+                    if len(border_pred) < 32:
+                        continue
+
+                    # Sum of top 25% of predicted border values (>= q75 so ties never give an empty selection)
+                    q75 = np.quantile(border_pred, .75)
+                    top_border_pred = border_pred[border_pred >= q75]
+                    top_border_height = top_border_pred.sum()
+                    top_border_area = len(top_border_pred)
+
+# borderprednonzero = pred[np.nonzero(border)] # all the prediction values inside
the border area
+# sortborderprednonzero = sorted(borderprednonzero) # sort the values
+# borderprednonzeroarea = len(borderprednonzero) # how many values are there?
+# quartborderarea = round(borderprednonzeroarea/4) # take one fourth of the values. there is some subtlety about how round() rounds but doesn't matter
+# topborderpred = sortborderprednonzero[quartborderarea:] # take top 3/4 of the predictions
+# topborderheight = np.sum(topborderpred) # sum over top 3/4 of the predictions
+# topborderarea = len(topborderpred) # area of 3/4 of predictions. In principle equal to 3/4 of borderprednonzeroarea but because of strange rounding, will just measure again
+
+                    # merge cells (guard against an empty top-quartile selection: 0/0 would be nan)
+                    if top_border_area > 0 and top_border_height / top_border_area > .99:
+                        orig1 = np.logical_or(orig1, objs[obj2,:,:])
+                        dil_objs[obj1,:,:] = np.logical_or(dil1, dil2)
+                        dil_objs[obj2,:,:] = np.zeros((wshshape[0], wshshape[1]))
+                        obj_coords[obj1,:] = get_bounding_box(dil_objs[obj1,:,:])
+                        print("Merged cell ",obj1+1," and ",obj2+1,".")
+
+
+#                if topborderarea > 8: # SJR: Not only must borderarea be greater than 0 but also have a little bit of border to go on.
+#                    #print(obj1+1, obj2+1, topborderheight/topborderarea)
+#                    if topborderheight/topborderarea > 0.99 : # SJR: We are really deep inside a cell, where the prediction is =1. Won't use: borderheight/borderarea > 0.95. Redundant.
+#                #print("--")
+#                #print(objcounter)
+#                #wsh=np.where(wsh==obj2+1, obj1+1, wsh)
+#                maskoriobj1 = np.uint8(np.multiply((maskoriobj1 > 0) | (oriobjs[obj2,:,:] > 0),1)) #have to do boolean then integer just to do an 'or'
+#                dilobjs[obj1,:,:] = np.uint8(np.multiply((maskobj1 > 0) | (maskobj2 > 0),1)) #have to do boolean then integer just to do an 'or'
+#                dilobjs[obj2,:,:] = np.zeros((wshshape[0],wshshape[1]))
+#                objcoords[obj1,0] = min(objcoords[obj1,0],objcoords[obj2,0])
+#                objcoords[obj1,1] = max(objcoords[obj1,1],objcoords[obj2,1])
+#                objcoords[obj1,2] = min(objcoords[obj1,2],objcoords[obj2,2])
+#                objcoords[obj1,3] = max(objcoords[obj1,3],objcoords[obj2,3])
+#                print("Merged cell ",obj1+1," and ",obj2+1,".")
+
+        wshclean = wshclean + orig1*objcounter
+
+    return wshclean
-        wshclean = wshclean + maskoriobj1*objcounter
-        #else:
-        #    display(obj1+1,' no longer there.')
     return wshclean
-#    return wsh
+def do_box_overlap(coord1, coord2):
+    """Checks if boxes, determined by their coordinates, overlap. Safety
+    margin of 2 pixels"""
+    return (
+        (coord1[0] - 2 < coord2[0] and coord1[1] + 2 > coord2[0]
+         or coord2[0] - 2 < coord1[0] and coord2[1] + 2 > coord1[0])
+        and (coord1[2] - 2 < coord2[2] and coord1[3] + 2 > coord2[2]
+             or coord2[2] - 2 < coord1[2] and coord2[3] + 2 > coord1[2]))
+
+def get_bounding_box(im):
+    """Returns bounding box of object in boolean image"""
+    coords = np.argwhere(im)
+
+    min0, min1 = coords.min(axis=0)
+    max0, max1 = coords.max(axis=0)
+    return np.array([min0, max0, min1, max1])