Skip to content
Snippets Groups Projects
Unverified Commit 4b9424ec authored by mattminder's avatar mattminder Committed by GitHub
Browse files

Merge pull request #3 from lpbsscientist/dev

Dev
parents 8c834210 043cd41d
No related branches found
No related tags found
No related merge requests found
...@@ -757,6 +757,10 @@ class App(QMainWindow): ...@@ -757,6 +757,10 @@ class App(QMainWindow):
inside of self.PredThreshSeg and it does the prediction of the neural inside of self.PredThreshSeg and it does the prediction of the neural
network, thresholds this prediction and then segments it. network, thresholds this prediction and then segments it.
""" """
self.WriteStatusBar('Running the neural network...')
self.Disable(self.button_cnn)
# creates a dialog window from the LaunchBatchPrediction.py file # creates a dialog window from the LaunchBatchPrediction.py file
dlg = lbp.CustomDialog(self) dlg = lbp.CustomDialog(self)
...@@ -814,10 +818,14 @@ class App(QMainWindow): ...@@ -814,10 +818,14 @@ class App(QMainWindow):
temp_mask = self.reader.LoadSeg(t, dlg.listfov.row(item)) temp_mask = self.reader.LoadSeg(t, dlg.listfov.row(item))
self.reader.SaveMask(t,dlg.listfov.row(item), temp_mask) self.reader.SaveMask(t,dlg.listfov.row(item), temp_mask)
self.m.UpdatePlots() self.ReloadThreeMasks()
self.ClearStatusBar()
self.EnableCNNButtons()
self.m.UpdatePlots()
self.ClearStatusBar()
self.EnableCNNButtons()
self.Enable(self.button_cnn)
def PredThreshSeg(self, timeindex, fovindex, thr_val, seg_val): def PredThreshSeg(self, timeindex, fovindex, thr_val, seg_val):
""" """
...@@ -833,24 +841,25 @@ class App(QMainWindow): ...@@ -833,24 +841,25 @@ class App(QMainWindow):
self.reader.SaveThresholdMask(timeindex, fovindex, self.m.ThresholdMask) self.reader.SaveThresholdMask(timeindex, fovindex, self.m.ThresholdMask)
self.m.SegmentedMask = self.reader.Segment(seg_val, timeindex,fovindex) self.m.SegmentedMask = self.reader.Segment(seg_val, timeindex,fovindex)
self.reader.SaveSegMask(timeindex, fovindex, self.m.SegmentedMask) self.reader.SaveSegMask(timeindex, fovindex, self.m.SegmentedMask)
self.reader.SaveMask(timeindex, fovindex, self.m.SegmentedMask)
def LaunchPrediction(self): # def LaunchPrediction(self):
"""This function is not used in the gui, but it can be used to launch # """This function is not used in the gui, but it can be used to launch
the prediction of one picture, with no thresholding and no segmentation # the prediction of one picture, with no thresholding and no segmentation
""" # """
if not(self.reader.TestPredExisting(self.Tindex, self.FOVindex)): # if not(self.reader.TestPredExisting(self.Tindex, self.FOVindex)):
self.WriteStatusBar('Running the neural network...') # self.WriteStatusBar('Running the neural network...')
self.Disable(self.button_cnn) # self.Disable(self.button_cnn)
self.reader.LaunchPrediction(self.Tindex, self.FOVindex) # self.reader.LaunchPrediction(self.Tindex, self.FOVindex)
#
self.Enable(self.button_cnn) # self.Enable(self.button_cnn)
#
self.button_cnn.setEnabled(False) # self.button_cnn.setEnabled(False)
self.button_threshold.setEnabled(True) # self.button_threshold.setEnabled(True)
self.button_segment.setEnabled(True) # self.button_segment.setEnabled(True)
self.button_cellcorespondance.setEnabled(True) # self.button_cellcorespondance.setEnabled(True)
self.ClearStatusBar() # self.ClearStatusBar()
def SelectChannel(self, index): def SelectChannel(self, index):
...@@ -1061,19 +1070,19 @@ class App(QMainWindow): ...@@ -1061,19 +1070,19 @@ class App(QMainWindow):
def CellCorrespActivation(self): def CellCorrespActivation(self):
self.Disable(self.button_cellcorespondance) self.Disable(self.button_cellcorespondance)
self.WriteStatusBar('Doing the cell correspondance') self.WriteStatusBar('Doing the cell correspondance')
if self.Tindex != 0: if self.Tindex != 0:
self.m.plotmask = self.reader.CellCorrespondance(self.Tindex, self.FOVindex) self.m.plotmask = self.reader.CellCorrespondance(self.Tindex, self.FOVindex)
self.m.updatedata() self.m.updatedata()
else: else:
self.m.plotmask = self.reader.LoadSeg(self.Tindex, self.FOVindex) self.m.plotmask = self.reader.LoadSeg(self.Tindex, self.FOVindex)
self.m.updatedata() self.m.updatedata()
self.Enable(self.button_cellcorespondance) self.Enable(self.button_cellcorespondance)
self.button_cellcorespondance.setChecked(False) self.button_cellcorespondance.setChecked(False)
self.ClearStatusBar() self.ClearStatusBar()
def SegmentBoxCheck(self): def SegmentBoxCheck(self):
......
...@@ -249,12 +249,13 @@ class Reader: ...@@ -249,12 +249,13 @@ class Reader:
else: else:
zeroarray = np.zeros([self.sizey, self.sizex],dtype = np.uint16) zeroarray = np.zeros([self.sizey, self.sizex],dtype = np.uint16)
file.create_dataset('/{}/{}'.format(self.fovlabels[currentFOV], self.tlabels[currentT]), data = zeroarray, compression = 'gzip') file.create_dataset('/{}/{}'.format(self.fovlabels[currentFOV], self.tlabels[currentT]),
data = zeroarray, compression = 'gzip')
file.close() file.close()
return zeroarray return zeroarray
def TestTimeExist(self,currentT, currentFOV, file): def TestTimeExist(self, currentT, currentFOV, file):
"""This method tests if the array which is requested by LoadMask """This method tests if the array which is requested by LoadMask
already exists or not in the hdf file. already exists or not in the hdf file.
""" """
...@@ -476,19 +477,19 @@ class Reader: ...@@ -476,19 +477,19 @@ class Reader:
def LaunchPrediction(self, currentT, currentFOV): def LaunchPrediction(self, currentT, currentFOV):
"""It launches the neural neutwork on the current image and creates """It launches the neural neutwork on the current image and creates
an hdf file with the prediction for the time T and corresponding FOV. an hdf file with the prediction for the time T and corresponding FOV.
""" """
file = h5py.File(self.predictname, 'r+') file = h5py.File(self.predictname, 'r+')
im = self.LoadOneImage(currentT, currentFOV) if not self.TestTimeExist(currentT, currentFOV, file):
im = skimage.exposure.equalize_adapthist(im) im = self.LoadOneImage(currentT, currentFOV)
im = im*1.0; im = skimage.exposure.equalize_adapthist(im)
pred = nn.prediction(im) im = im*1.0;
file.create_dataset('/{}/{}'.format(self.fovlabels[currentFOV], pred = nn.prediction(im)
self.tlabels[currentT]), data = pred, compression = 'gzip', file.create_dataset('/{}/{}'.format(self.fovlabels[currentFOV],
compression_opts = 7) self.tlabels[currentT]), data = pred,
compression = 'gzip', compression_opts = 7)
file.close() file.close()
......
# -*- coding: utf-8 -*-
"""
Created on Thu Dec 12 10:03:28 2019
Cell Correspondance
"""
import numpy as np
import matplotlib.pyplot as plt
# THIS IS NOT USED ANYMORE
def cell_frame_intersection(nextmask, prevmask, flag):
"""
This function computes the intersection between the segmented frame and
the previous frame. It looks at all the segments in the new mask
(= nextmask) and for each of the segments it finds the corresponding
coordinates. These coordinates are then plugged into the previous mask
and in the corresponding region of the previous mask it takes the cell
which has the most pixel values in this region (compare the pixel counts).
And then takes the value of the value which has the biggest number
of counts in the corresponding region and sets to the segment in the new
mask.
If the cell value given by the intersection between the frame is = 0,
than it means that the new segment intersects mostly with background,
so it is assumed that a new cell is created and a new value is given
to this cell.
of course, it does not work for all the cells (some move, some grow) so
the intersection might occur with
"""
nwcells = np.unique(nextmask)
nwcells = nwcells[nwcells >0]
vals = np.empty(len(nwcells), dtype=int)
vals[:] = -1
coord = np.empty(len(nwcells))
coord = list(coord)
maxvalue = np.amax(prevmask)
double_smaller = []
double_bigger = []
for i, cell in enumerate(nwcells):
v, c = np.unique(prevmask[nextmask==cell], return_counts=True)
valtemp = v[np.argmax(c)]
coord[i] = nextmask == cell
if valtemp != 0:
if valtemp in vals:
tempind = []
cts = []
for k, allval in enumerate(vals):
if allval == valtemp:
valbool, ctmp = np.unique(coord[k], return_counts = True)
cts.append(ctmp[valbool == True])
tempind.append(k)
# cts = np.array(cts)
maxpreviouscts = int(np.amax(cts))
biggestcoord_index = tempind[np.argmax(cts)]
c = int(c[np.argmax(c)])
if len(tempind) == 1:
if c > maxpreviouscts:
double_smaller.append(coord[biggestcoord_index])
double_bigger.append(nextmask == cell)
else:
double_smaller.append(nextmask == cell)
double_bigger.append(coord[biggestcoord_index])
else:
if maxpreviouscts >= c:
double_smaller.append(nextmask == cell)
else:
for k in range(0, len(double_bigger)):
if not(False in (double_bigger[k] == coord[biggestcoord_index])):
double_smaller.append(double_bigger[k])
double_bigger[k] = coord[biggestcoord_index]
vals[i] = valtemp
else:
if flag:
maxvalue = maxvalue + 1
vals[i] = maxvalue
else:
vals[i] = 0
out = nextmask.copy()
for k, v in zip(nwcells, vals):
out[k==nextmask] = v
return out, double_bigger, double_smaller
def CellCorrespondance(nextmask, prevmask, returnbool):
nm = nextmask.copy()
pm = prevmask.copy()
NotifyRegionMask = nextmask.copy()
NotifyRegionMask[NotifyRegionMask > 0] = 0
CorrespondanceMask, bigcluster, smallcluster = cell_frame_intersection(nm, pm, returnbool)
cm = CorrespondanceMask.copy()
oldcells = np.unique(prevmask)
oldcells = oldcells[oldcells > 0]
for cellval in oldcells:
if not(cellval in CorrespondanceMask):
NotifyRegionMask[prevmask == cellval] = 1
for coordinates in smallcluster:
val = np.unique(prevmask[coordinates])
for cellval in oldcells:
if not(cellval in CorrespondanceMask):
if cellval in val:
CorrespondanceMask[coordinates] = cellval
NotifyRegionMask[coordinates] = 1
NotifyRegionMask[coordinates] = 1
return CorrespondanceMask, NotifyRegionMask
def CellCorrespondancePlusTheReturn(nextmask, prevmask):
firstcorresp, notifymask = CellCorrespondance(nextmask, prevmask, True)
firstmask = firstcorresp.copy()
pm = prevmask.copy()
fm = firstcorresp.copy()
# pm2 = prevmask.copy()
cmreturn, bigclustr, smallclustr = cell_frame_intersection(pm, firstmask, False)
# cmreturn, notifymask = CellCorrespondance(pm, firstmask, False)
oldcells = np.unique(fm)
oldcells = oldcells[oldcells > 0]
for cellval in oldcells:
if not(cellval in cmreturn):
notifymask[fm == cellval] = 2
for coordinates in smallclustr:
val = np.unique(fm[coordinates])
for cellval in oldcells:
if not(cellval in cmreturn):
if cellval in val:
notifymask[coordinates] = 2
notifymask[coordinates] = 2
return fm, notifymask
"""
Source of the code: https://github.com/zhixuhao/unet
"""
from __future__ import print_function
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import numpy as np
import os
import glob
import skimage.io as io
import skimage.transform as trans
Sky = [128,128,128]
Building = [128,0,0]
Pole = [192,192,128]
Road = [128,64,128]
Pavement = [60,40,222]
Tree = [128,128,0]
SignSymbol = [192,128,128]
Fence = [64,64,128]
Car = [64,0,128]
Pedestrian = [64,64,0]
Bicyclist = [0,128,192]
Unlabelled = [0,0,0]
COLOR_DICT = np.array([Sky, Building, Pole, Road, Pavement,
Tree, SignSymbol, Fence, Car, Pedestrian, Bicyclist, Unlabelled])
def adjustData(img,mask,flag_multi_class,num_class):
if(flag_multi_class):
img = img / 255
mask = mask[:,:,:,0] if(len(mask.shape) == 4) else mask[:,:,0]
new_mask = np.zeros(mask.shape + (num_class,))
for i in range(num_class):
#for one pixel in the image, find the class in mask and convert it into one-hot vector
#index = np.where(mask == i)
#index_mask = (index[0],index[1],index[2],np.zeros(len(index[0]),dtype = np.int64) + i) if (len(mask.shape) == 4) else (index[0],index[1],np.zeros(len(index[0]),dtype = np.int64) + i)
#new_mask[index_mask] = 1
new_mask[mask == i,i] = 1
new_mask = np.reshape(new_mask,(new_mask.shape[0],new_mask.shape[1]*new_mask.shape[2],new_mask.shape[3])) if flag_multi_class else np.reshape(new_mask,(new_mask.shape[0]*new_mask.shape[1],new_mask.shape[2]))
mask = new_mask
elif(np.max(img) > 1):
img = img / 255
mask = mask /255
mask[mask > 0.5] = 1
mask[mask <= 0.5] = 0
return (img,mask)
def trainGenerator(batch_size,train_path,image_folder,mask_folder,aug_dict,image_color_mode = "grayscale",
mask_color_mode = "grayscale",image_save_prefix = "image",mask_save_prefix = "mask",
flag_multi_class = False,num_class = 2,save_to_dir = None,target_size = (256,256),seed = 1):
'''
can generate image and mask at the same time
use the same seed for image_datagen and mask_datagen to ensure the transformation for image and mask is the same
if you want to visualize the results of generator, set save_to_dir = "your path"
'''
image_datagen = ImageDataGenerator(**aug_dict)
mask_datagen = ImageDataGenerator(**aug_dict)
image_generator = image_datagen.flow_from_directory(
train_path,
classes = [image_folder],
class_mode = None,
color_mode = image_color_mode,
target_size = target_size,
batch_size = batch_size,
save_to_dir = save_to_dir,
save_prefix = image_save_prefix,
seed = seed)
mask_generator = mask_datagen.flow_from_directory(
train_path,
classes = [mask_folder],
class_mode = None,
color_mode = mask_color_mode,
target_size = target_size,
batch_size = batch_size,
save_to_dir = save_to_dir,
save_prefix = mask_save_prefix,
seed = seed)
train_generator = zip(image_generator, mask_generator)
for (img,mask) in train_generator:
img,mask = adjustData(img,mask,flag_multi_class,num_class)
yield (img,mask)
def testGenerator(test_path,num_image = 30,target_size = (256,256),
flag_multi_class = False,as_gray = True):
for i in range(num_image):
img = io.imread(os.path.join(test_path,"%d.png"%i),as_gray = as_gray)
img = img / 255
img = trans.resize(img,target_size)
img = np.reshape(img,img.shape+(1,)) if (not flag_multi_class) else img
img = np.reshape(img,(1,)+img.shape)
yield img
def geneTrainNpy(image_path,mask_path,flag_multi_class = False,num_class = 2,image_prefix = "image",mask_prefix = "mask",image_as_gray = True,mask_as_gray = True):
image_name_arr = glob.glob(os.path.join(image_path,"%s*.png"%image_prefix))
image_arr = []
mask_arr = []
for index,item in enumerate(image_name_arr):
img = io.imread(item,as_gray = image_as_gray)
img = np.reshape(img,img.shape + (1,)) if image_as_gray else img
mask = io.imread(item.replace(image_path,mask_path).replace(image_prefix,mask_prefix),as_gray = mask_as_gray)
mask = np.reshape(mask,mask.shape + (1,)) if mask_as_gray else mask
img,mask = adjustData(img,mask,flag_multi_class,num_class)
image_arr.append(img)
mask_arr.append(mask)
image_arr = np.array(image_arr)
mask_arr = np.array(mask_arr)
return image_arr,mask_arr
def labelVisualize(num_class,color_dict,img):
img = img[:,:,0] if len(img.shape) == 3 else img
img_out = np.zeros(img.shape + (3,))
for i in range(num_class):
img_out[img == i,:] = color_dict[i]
return img_out / 255
def saveResult(save_path,npyfile,flag_multi_class = False,num_class = 2):
for i,item in enumerate(npyfile):
img = labelVisualize(num_class,COLOR_DICT,item) if flag_multi_class else item[:,:,0]
io.imsave(os.path.join(save_path,"%d_predict.png"%i),img)
import numpy as np
import skimage
from skimage import io
from skimage.util import img_as_ubyte
from skimage import morphology
#############
# #
# READING #
# #
#############
def read_im_tiff(path, num_frames=None):
"""
Read a multiple page tiff images file, adapt the contrast and stock each
frame in a np.array
Param:
path: path to the tiff movie
num_frames: (integer) number of frames to read
Return:
images:
"""
ims = io.imread(path)
images = []
if(num_frames == None):
num_frames = ims.shape[0]
for i in range(num_frames):
im = skimage.exposure.equalize_adapthist(ims[i])
images.append(np.array(im))
return images
def read_lab_tiff(path, num_frames=None):
"""
Read a multiple page tiff label file and stock each frame in a np.array
Param:
path: path to the tiff movie
num_frames: (integer) number of frames to read
Return:
images: (np.array) array containing the desired frames
"""
ims = io.imread(path)
images = []
if(num_frames == None):
num_frames = ims.shape[0]
for i in range(num_frames):
images.append(np.array(ims[i]))
return images
################################
# #
# GENERAL PROCESSING METHODS #
# #
################################
def pad(im, size):
"""
Carry a mirror padding on an image to end up with a squared image dividable in multiple tiles of shape size x size pixels
Param:
im: (np.array) input image
size: (integer) size of the tiles
Return:
out: (np.array) output image reshaped to the good dimension
"""
# add mirror part along x-axis
nx = im.shape[1]
ny = im.shape[0]
if(nx % size != 0.0):
restx = int(size - nx % size)
outx = np.pad(im, ((0,0), (0,restx)), 'symmetric')
else:
outx = im
if(ny % size != 0.0):
resty = int(size - ny % size)
out = np.pad(outx, ((0,resty), (0,0)), 'symmetric')
else:
out = outx
return out
def threshold(im,th = None):
"""
Binarize an image with a threshold given by the user, or if the threshold is None, calculate the better threshold with isodata
Param:
im: a numpy array image (numpy array)
th: the value of the threshold (feature to select threshold was asked by the lab)
Return:
bi: threshold given by the user (numpy array)
"""
if th == None:
th = skimage.filters.threshold_isodata(im)
bi = im
bi[bi > th] = 255
bi[bi <= th] = 0
return bi
def edge_detection(im):
"""
Detect the edges on a label image
Param:
im: (np.array) input image
Return:
contour: (np.array) output image containing the edges of the input image
"""
contour = np.zeros((im.shape[0], im.shape[1]))
vals = np.unique(im)
for i in range(0,vals.shape[0]):
a = np.zeros((im.shape[0], im.shape[1]))
a[im == vals[i]] = 255
dilated = skimage.morphology.dilation(a, selem=None, out=None)
eroded = skimage.morphology.erosion(a, selem=None, out=None)
edge = dilated - eroded
contour = contour + edge
contour[contour >= 255] = 255
return contour
def split(im, size):
"""
split an squared image with good dimensions (output of pad()) into tiles of shape size x size pixels
Param:
im: (np.array) input image
size: (integer) size of the tiles
Return:
ims: (list of np.array) output tiles
"""
nx = im.shape[1]
ny = im.shape[0]
k_max = int(nx / size) # number of 256 slices along x axis
l_max = int(ny / size) # number of 256 slices along y axis
ims = []
for l in range(0, l_max):
for k in range(0, k_max):
frame = np.zeros((size,size))
lo_x = size * k
hi_x = size * k + size
lo_y = size * l
hi_y = size * l + size
frame = im[lo_y:hi_y, lo_x:hi_x]
# padding of the image to avoid border artefacts due to the convolutions
# out = np.pad(frame, ((10,10), (10,10)), 'symmetric')
ims.append(frame)
return ims
def split_data(im, lab, ratio, seed):
"""split the dataset based on the split ratio."""
# set seed
np.random.seed(seed)
index= np.arange(len(im))
np.random.shuffle(index)
num = int(ratio*len(index))
im = im[index]
lab = lab[index]
im_tr = im[0:num,:,:]
lab_tr = lab[0:num,:,:]
im_te = im[num:,:,:]
lab_te = lab[num:,:,:]
return im_tr, lab_tr, im_te, lab_te
####################################
# #
# TRAIN AND TEST SETS GENERATORS #
# #
####################################
def generate_test_set(im, out_im_path):
"""
Generate the testing set from raw data
Param:
im: (np.array) input image
out_im_path: path to save the testing image
Return:
img_num: (integer) number of image produced
resized_shape: (tupple) shape of the padded image
original_shape: (tupple) shape of the original image
"""
im = skimage.exposure.equalize_adapthist(im)
original_shape = im.shape
#resizing the input images
padded = pad(im, 236)
resized_shape = padded.shape
# splitting
splited = split(padded, 236)
# padding the tiles to get 256x256 tiles
padded_split = []
for tile in splited:
padded_split.append(np.pad(tile, ((10,10), (10,10)), 'symmetric'))
img_num = len(padded_split)
#saving the ouput images
for i in range(len(padded_split)):
name = str(i)
io.imsave( out_im_path + name + ".png", img_as_ubyte(padded_split[i]) )
return img_num, resized_shape, original_shape
def generate_tr_val_set(im_col, lab_col, tr_im_path, tr_lab_path, val_im_path, val_lab_path):
"""
Randomly generate training, validation and testing set from a given collection of images and labels with a 50/25/25 ratio
for detection of whole cells
Params:
im_col: (list of string) list of images (tiff movies) to include in the sets
lab_col: (list of string) list of labels (tiff movies) to include in the sets
tr_im_path: path to save training images
tr_lab_path: path to save training labels
val_im_path: path to save validation images
val_lab_path: path to save validation labels
Returns:
tr_len: (integer) number of samples in the training set
val_len: (integer) number of samples in the validation set
"""
# reading raw data
ims = []
labs = []
for im in im_col:
print('im', im)
ims = ims + read_im_tiff(im)
for lab in lab_col:
print('label',lab)
labs = labs + read_lab_tiff(lab)
ims_out = []
labs_out = []
for i in range( len(ims) ):
# resizing images
im_out = pad(ims[i], 236)
# resizing and binarizing whole cell label
threshold(labs[i],0)
lab_out = pad(labs[i], 236)
# splitting the images
split_im = split(im_out, 236)
split_lab = split(lab_out, 236)
# discarding images showing background only
# padding the tiles
split_im_out = []
split_lab_out = []
for j in range( len(split_lab) ):
if( np.sum(split_lab[j]) > 0.1 * 255 * 256 * 256 ):
split_im_out.append(np.pad(split_im[j], ((10,10), (10,10)), 'symmetric'))
split_lab_out.append(np.pad(split_lab[j], ((10,10), (10,10)), 'symmetric'))
ims_out = ims_out + split_im_out
labs_out = labs_out + split_lab_out
# splitting the list into multiple sets
im_tr, lab_tr, im_val, lab_val = split_data(np.array(ims_out),
np.array(labs_out),
0.75,
1)
tr_len = im_tr.shape[0]
val_len = im_val.shape[0]
#saving the images
for i in range(tr_len):
io.imsave(tr_im_path + str(i) + ".png", img_as_ubyte(im_tr[i,:,:]))
io.imsave(tr_lab_path + str(i) + ".png", (lab_tr[i,:,:]))
for j in range(val_len):
io.imsave(val_im_path + str(j) + ".png", img_as_ubyte(im_val[j,:,:]))
io.imsave(val_lab_path + str(j) + ".png", (lab_val[j,:,:]))
return tr_len, val_len
def reconstruct_result(tile_size, result, resized_shape, origin_shape):
"""
Assemble a set of tiles to reconstruct the original, unsplitted image
Param:
tile_size: (integer) size of the tiles for the reconstruction
result: (np.array) result images of the network prediction
out_result_path: path to save the results
resized_shape: (tuple) size of the image padded for the splitting
origin_shape: (tuple) size or the raw images
Return:
out: (np.array) array containing the reconstructed images
"""
nx, ny = int(resized_shape[1] / tile_size), int(resized_shape[0] / tile_size)
out = np.empty(resized_shape)
i = 0
for l in range(ny):
for k in range(nx):
lo_x = tile_size * k
hi_x = tile_size * k + tile_size
lo_y = tile_size * l
hi_y = tile_size * l + tile_size
out[lo_y:hi_y, lo_x:hi_x] = result[i,:,:]
i = i+1
return out[ 0:origin_shape[0], 0:origin_shape[1] ]
...@@ -6,9 +6,10 @@ Created on Sat Dec 21 18:54:10 2019 ...@@ -6,9 +6,10 @@ Created on Sat Dec 21 18:54:10 2019
""" """
import os import os
from model import unet from model import unet
import data
import numpy as np import numpy as np
import skimage import skimage
from skimage import io
import skimage.transform as trans
def create_directory_if_not_exists(path): def create_directory_if_not_exists(path):
...@@ -57,9 +58,9 @@ def prediction(im): ...@@ -57,9 +58,9 @@ def prediction(im):
create_directory_if_not_exists(path_test) create_directory_if_not_exists(path_test)
# WHOLE CELL PREDICTION # WHOLE CELL PREDICTION
testGene = data.testGenerator(path_test, testGene = testGenerator(path_test,
1, 1,
target_size = (2048,2048)) target_size = (2048,2048))
model = unet(pretrained_weights = None, model = unet(pretrained_weights = None,
input_size = (2048,2048,1)) input_size = (2048,2048,1))
...@@ -79,3 +80,13 @@ def prediction(im): ...@@ -79,3 +80,13 @@ def prediction(im):
return res return res
def testGenerator(test_path,num_image = 30,target_size = (256,256),
flag_multi_class = False,as_gray = True):
for i in range(num_image):
img = io.imread(os.path.join(test_path,"%d.png"%i),as_gray = as_gray)
img = img / 255
img = trans.resize(img,target_size)
img = np.reshape(img,img.shape+(1,)) if (not flag_multi_class) else img
img = np.reshape(img,(1,)+img.shape)
yield img
"""
Source of the code: https://github.com/mattminder/yeastSegHelpers
"""
import numpy as np
def cell_correspondance_frame(truth, pred):
"""
Finds for one frame the correspondance between the true and the predicted cells.
Returns dictionary from predicted cell to true cell.
"""
keys = np.unique(pred)
vals = np.empty(len(keys), dtype=int)
for i, cell in enumerate(keys):
v, c = np.unique(truth[pred==cell], return_counts=True)
vals[i] = v[np.argmax(c)]
return dict(zip(keys, vals))
def split_dict(cc):
"""
Returns dictionary of cells in predicted image, to whether they were split unnecessarily
(includes cells that were assigned to background)
"""
truth = list(cc.values())
pred = list(cc.keys())
v, c = np.unique(list(truth), return_counts=True)
split_truth = v[c>1]
return dict(zip(pred, np.isin(truth, split_truth)))
def fused_dict(cc, truth, pred):
"""
Returns dictionary of cells in predicted image to whether they were fused unnecessarily.
"""
rev_cc = cell_correspondance_frame(pred, truth)
rev_split_dict = split_dict(rev_cc)
out = dict()
for p in cc:
out[p] = rev_split_dict[cc[p]]
return out
def nb_splits(cc):
"""
Counts the number of unnecessary splits in the predicted image (ignores background).
"""
v, c = np.unique(list(cc.values()), return_counts=True)
return (c[v>0]-1).sum()
def nb_fusions(cc, truth):
"""
Counts number of cells that were fused.
"""
true_cells = set(np.unique(truth))
pred_cells = set(list(cc.values()))
return len(true_cells - pred_cells)
def nb_artefacts(cc):
"""
Number of predicted cells that are really background.
"""
valarr = np.array(list(cc.values()))
keyarr = np.array(list(cc.keys()))
return (valarr[keyarr>0]==0).sum()
def nb_false_negatives(truth, pred):
"""
Number of cells of mask that were detected as background in the prediction.
"""
rev_cc = cell_correspondance_frame(pred, truth)
return nb_artefacts(rev_cc)
def over_undershoot(truth, pred, cc, look_at):
"""
Calculates average number of pixels that were overshot by cell. Uses look_at, which
is a dict from pred cells to whether they should be considered (allows to exclude wrongly
fused or split cells)
"""
overshoot = 0
undershoot = 0
cellcount = 0
for p in cc:
if not look_at[p]:
continue
t = cc[p]
if t==0 or p==0: # Disregard if truth or prediction is background
continue
cellcount += 1
overshoot += (pred[truth!=t]==p).sum()
undershoot += (truth[pred!=p]==t).sum()
return overshoot/cellcount, undershoot/cellcount
def average_pred_area(pred, cc, look_at):
"""
Calculates average predicted area of all considered cells
"""
area = 0
cellcount = 0
for p in cc:
if not look_at[p]:
continue
if p==0:
continue
cellcount += 1
area += (pred==p).sum()
return area/cellcount
def average_true_area(truth, cc, look_at):
"""
Calculates average true area of all considered cells
"""
area = 0
cellcount = 0
for p in cc:
if not look_at[p]:
continue
t = cc[p]
if t==0:
continue
cellcount += 1
area += (truth==t).sum()
return area/cellcount
def n_considered_cells(cc, look_at):
"""
Calculates the number of considered cells
"""
count = 0
for p in cc:
if not look_at[p]:
continue
if p==0 or cc[p]==0:
continue
count += 1
return count
def quality_measures(truth, pred):
"""
Tests quality of prediction with four statistics:
Number Fusions: How many times are multiple true cells predicted as a single cell?
Number Splits: How many times is a single true cell split into multiple predicted cell?
Av. Overshoot: How many pixels are wrongly predicted to belong to the cell on average
Av. Undershoot: How many pixels are wrongly predicted to not belong to the cell on average
Av. true area: Average area of cells of truth that are neither split nor fused
Av. pred area: Average area of predicted cells that are neither split nor fused
Nb considered cells:Number of cells that are neither split nor fused
"""
cc = cell_correspondance_frame(truth, pred)
res = dict()
# Get indices of cells that are not useable for counting under- and overshooting
is_split = split_dict(cc)
is_fused = fused_dict(cc, truth, pred)
look_at = dict()
for key in is_split:
look_at[key] = not (is_split[key] or is_fused[key])
# Result Calculation
res["Number Fusions"] = nb_fusions(cc, truth)
res["Number Splits"] = nb_splits(cc)
res["Nb False Positives"] = nb_artefacts(cc)
res["Nb False Negatives"] = nb_false_negatives(truth, pred)
res["Average Overshoot"], res["Average Undershoot"] = over_undershoot(truth, pred, cc, look_at)
res["Av. True Area"] = average_true_area(truth, cc, look_at)
res["Av. Pred Area"] = average_pred_area(pred, cc, look_at)
res["Nb Considered Cells"] = n_considered_cells(cc, look_at)
return res
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment