diff --git a/GUI_main.py b/GUI_main.py
index 223e3faf9525415fbc23b10ccfc682fc40713539..a3ee0b37d33ebaaf7b17a92322a747ff2a75b6c6 100644
--- a/GUI_main.py
+++ b/GUI_main.py
@@ -757,6 +757,11 @@ class App(QMainWindow):
         inside of self.PredThreshSeg and it does the prediction of the neural
         network, thresholds this prediction and then segments it.
         """
+        
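+        # show progress and disable the CNN button so the prediction cannot be relaunched while it runs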
+        self.WriteStatusBar('Running the neural network...')
+        self.Disable(self.button_cnn)
+
         # creates a dialog window from the LaunchBatchPrediction.py file
         dlg = lbp.CustomDialog(self)
         
@@ -814,10 +818,14 @@ class App(QMainWindow):
                             temp_mask = self.reader.LoadSeg(t, dlg.listfov.row(item))
                             self.reader.SaveMask(t,dlg.listfov.row(item), temp_mask)
             
-            self.m.UpdatePlots()
-            self.ClearStatusBar()
-            self.EnableCNNButtons()
-   
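+            # refresh the three displayed masks now that the new segmentations have been saved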
+            self.ReloadThreeMasks()
+
+        self.m.UpdatePlots()
+        self.ClearStatusBar()
+        self.EnableCNNButtons()
+        self.Enable(self.button_cnn)
+    
     
     def PredThreshSeg(self, timeindex, fovindex, thr_val, seg_val):
           """
@@ -833,24 +841,26 @@
           self.reader.SaveThresholdMask(timeindex, fovindex, self.m.ThresholdMask)
           self.m.SegmentedMask = self.reader.Segment(seg_val, timeindex,fovindex)
           self.reader.SaveSegMask(timeindex, fovindex, self.m.SegmentedMask)
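+          # also write the segmentation into the mask file so it becomes the editable mask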
+          self.reader.SaveMask(timeindex, fovindex, self.m.SegmentedMask)
     
     
-    def LaunchPrediction(self):
-        """This function is not used in the gui, but it can be used to launch
-        the prediction of one picture, with no thresholding and no segmentation
-        """
-        if not(self.reader.TestPredExisting(self.Tindex, self.FOVindex)):
-            self.WriteStatusBar('Running the neural network...')
-            self.Disable(self.button_cnn)
-            self.reader.LaunchPrediction(self.Tindex, self.FOVindex)
-            
-            self.Enable(self.button_cnn)
-            
-            self.button_cnn.setEnabled(False)
-            self.button_threshold.setEnabled(True)
-            self.button_segment.setEnabled(True)
-            self.button_cellcorespondance.setEnabled(True)
-            self.ClearStatusBar()
+#    def LaunchPrediction(self):
+#        """This function is not used in the gui, but it can be used to launch
+#        the prediction of one picture, with no thresholding and no segmentation
+#        """
+#        if not(self.reader.TestPredExisting(self.Tindex, self.FOVindex)):
+#            self.WriteStatusBar('Running the neural network...')
+#            self.Disable(self.button_cnn)
+#            self.reader.LaunchPrediction(self.Tindex, self.FOVindex)
+#            
+#            self.Enable(self.button_cnn)
+#            
+#            self.button_cnn.setEnabled(False)
+#            self.button_threshold.setEnabled(True)
+#            self.button_segment.setEnabled(True)
+#            self.button_cellcorespondance.setEnabled(True)
+#            self.ClearStatusBar()
         
     
     def SelectChannel(self, index):
@@ -1061,19 +1070,19 @@ class App(QMainWindow):
         
      
     def CellCorrespActivation(self):
-            self.Disable(self.button_cellcorespondance)
-            self.WriteStatusBar('Doing the cell correspondance')
+        self.Disable(self.button_cellcorespondance)
+        self.WriteStatusBar('Doing the cell correspondence')
 
-            if self.Tindex != 0:
-                self.m.plotmask = self.reader.CellCorrespondance(self.Tindex, self.FOVindex)
-                self.m.updatedata()
-            else:
-                self.m.plotmask = self.reader.LoadSeg(self.Tindex, self.FOVindex)
-                self.m.updatedata()
+        if self.Tindex != 0:
+            self.m.plotmask = self.reader.CellCorrespondance(self.Tindex, self.FOVindex)
+            self.m.updatedata()
+        else:
+            self.m.plotmask = self.reader.LoadSeg(self.Tindex, self.FOVindex)
+            self.m.updatedata()
 
-            self.Enable(self.button_cellcorespondance)
-            self.button_cellcorespondance.setChecked(False)
-            self.ClearStatusBar()
+        self.Enable(self.button_cellcorespondance)
+        self.button_cellcorespondance.setChecked(False)
+        self.ClearStatusBar()
         
         
     def SegmentBoxCheck(self):
diff --git a/disk/InteractionDisk_temp.py b/disk/InteractionDisk_temp.py
index ee6f1d1040f694831f01adfda1098501dbeca82f..6e57a284da84f4a46e074cb580c3414c7d521bea 100644
--- a/disk/InteractionDisk_temp.py
+++ b/disk/InteractionDisk_temp.py
@@ -249,12 +249,13 @@ class Reader:
         
         else:
             zeroarray = np.zeros([self.sizey, self.sizex],dtype = np.uint16)
-            file.create_dataset('/{}/{}'.format(self.fovlabels[currentFOV], self.tlabels[currentT]), data = zeroarray, compression = 'gzip')
+            file.create_dataset('/{}/{}'.format(self.fovlabels[currentFOV], self.tlabels[currentT]), 
+                                data = zeroarray, compression = 'gzip')
             file.close()
             return zeroarray
             
             
-    def TestTimeExist(self,currentT, currentFOV, file):
+    def TestTimeExist(self, currentT, currentFOV, file):
         """This method tests if the array which is requested by LoadMask
         already exists or not in the hdf file.
         """
@@ -476,19 +477,20 @@
         
         
     def LaunchPrediction(self, currentT, currentFOV):
-        
         """It launches the neural neutwork on the current image and creates 
         an hdf file with the prediction for the time T and corresponding FOV. 
         """
 
-        file = h5py.File(self.predictname, 'r+')        
-        im = self.LoadOneImage(currentT, currentFOV)
-        im = skimage.exposure.equalize_adapthist(im)
-        im = im*1.0;	
-        pred = nn.prediction(im)
-        file.create_dataset('/{}/{}'.format(self.fovlabels[currentFOV], 
-                                    self.tlabels[currentT]), data = pred, compression = 'gzip', 
-                                    compression_opts = 7)
+        file = h5py.File(self.predictname, 'r+') 
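+        # compute and store a prediction only if none exists yet for this time point and FOV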
+        if not self.TestTimeExist(currentT, currentFOV, file):
+            im = self.LoadOneImage(currentT, currentFOV)
+            im = skimage.exposure.equalize_adapthist(im)
+            im = im*1.0  # cast to float before prediction
+            pred = nn.prediction(im)
+            file.create_dataset('/{}/{}'.format(self.fovlabels[currentFOV], 
+                                self.tlabels[currentT]), data = pred, 
+                                compression = 'gzip', compression_opts = 7)
         file.close()
             
 
diff --git a/unet/CellCorrespondance.py b/unet/CellCorrespondance.py
deleted file mode 100644
index a72fbcc19670a904c5ae53f0d6677c13d1c759c9..0000000000000000000000000000000000000000
--- a/unet/CellCorrespondance.py
+++ /dev/null
@@ -1,218 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-Created on Thu Dec 12 10:03:28 2019
-
-Cell Correspondance
-"""
-
-import numpy as np
-import matplotlib.pyplot as plt
-
-
-# THIS IS NOT USED ANYMORE
-
-
-
-
-
-def cell_frame_intersection(nextmask, prevmask, flag):
-    """
-    This function computes the intersection between the segmented frame and
-    the previous frame. It looks at all the segments in the new mask
-    (= nextmask) and for each of the segments it finds the corresponding 
-    coordinates. These coordinates are then plugged into the previous mask
-    and in the corresponding region of the previous mask it takes the cell
-    which has the most pixel values in this region (compare the pixel counts).
-    And then takes the value of the value which has the biggest number
-    of counts in the corresponding region and sets to the segment in the new 
-    mask. 
-    If the cell value given by the intersection between the frame is = 0, 
-    than it means that the new segment intersects mostly with background,
-    so it is assumed that a new cell is created and a new value is given
-    to this cell.
-    of course, it does not work for all the cells (some move, some grow) so 
-    the intersection might occur with 
-    """
-    nwcells = np.unique(nextmask)
-    nwcells = nwcells[nwcells >0]
-    vals = np.empty(len(nwcells), dtype=int)
-    vals[:] = -1
-    coord = np.empty(len(nwcells))
-    coord = list(coord)
-    maxvalue = np.amax(prevmask)
-    double_smaller = []
-    double_bigger = []
-    
-    for i, cell in enumerate(nwcells):
-
-            v, c = np.unique(prevmask[nextmask==cell], return_counts=True)
-            
-            valtemp = v[np.argmax(c)]
-            coord[i] = nextmask == cell
-            
-            if valtemp != 0:
-                if valtemp in vals:
-                    tempind = []
-                    cts = []
-                    
-                    for k, allval in enumerate(vals):
-                            if allval == valtemp:
-
-                                    valbool, ctmp = np.unique(coord[k], return_counts = True)
-
-                                    cts.append(ctmp[valbool == True])
-                                    tempind.append(k)
-    
-#                    cts = np.array(cts)
-                    maxpreviouscts = int(np.amax(cts))
-                    biggestcoord_index = tempind[np.argmax(cts)]
-                    c = int(c[np.argmax(c)])
-                        
-                                        
-                    if len(tempind) == 1:
-                        if c > maxpreviouscts:
-                                    double_smaller.append(coord[biggestcoord_index])
-                                    double_bigger.append(nextmask == cell)
-                        else:
-                                    double_smaller.append(nextmask == cell)
-                                    double_bigger.append(coord[biggestcoord_index])
-                    else:
-                       
-                        if maxpreviouscts >= c:
-                            double_smaller.append(nextmask == cell)
-                        
-                        else:
-
-                            for k in range(0, len(double_bigger)):
-                                if not(False in (double_bigger[k] == coord[biggestcoord_index])):
-                                    double_smaller.append(double_bigger[k])
-                                    double_bigger[k] = coord[biggestcoord_index]
-                                    
-                                    
-                    
-                vals[i] = valtemp
-                
-                
-                
-                
-            else:
-                if flag:
-                    maxvalue = maxvalue + 1
-                    vals[i] = maxvalue
-                else:
-                    vals[i] = 0
-     
-    out = nextmask.copy()
-    for k, v in zip(nwcells, vals):
-        out[k==nextmask] = v
-    return out, double_bigger, double_smaller
-
-
-
-
-
- 
-
-
-def CellCorrespondance(nextmask, prevmask, returnbool):
-    
-    nm = nextmask.copy()
-    pm = prevmask.copy()
-        
-    NotifyRegionMask = nextmask.copy()
-    NotifyRegionMask[NotifyRegionMask > 0] = 0
-    
-    CorrespondanceMask, bigcluster, smallcluster = cell_frame_intersection(nm, pm, returnbool)
-    
-    cm = CorrespondanceMask.copy()
-     
-    oldcells = np.unique(prevmask)
-    oldcells = oldcells[oldcells > 0]
-    
-    for cellval in oldcells:
-        if not(cellval in CorrespondanceMask):
-
-            NotifyRegionMask[prevmask == cellval] = 1
-            
-
-    
-    for coordinates in smallcluster:        
-        val = np.unique(prevmask[coordinates])
-        for cellval in oldcells:
-            if not(cellval in CorrespondanceMask):
-                if cellval in val:
-                
-                    CorrespondanceMask[coordinates] = cellval
-                    NotifyRegionMask[coordinates] = 1
-
-                    
-        NotifyRegionMask[coordinates] = 1
-
-
-            
-            
-
-
-    return CorrespondanceMask, NotifyRegionMask
-
-
-
-    
-def CellCorrespondancePlusTheReturn(nextmask, prevmask):
-    
-    firstcorresp, notifymask = CellCorrespondance(nextmask, prevmask, True)
-    
-    firstmask = firstcorresp.copy()
-    pm = prevmask.copy()
-    fm = firstcorresp.copy()
-#    pm2 = prevmask.copy()
-    
-    cmreturn, bigclustr, smallclustr = cell_frame_intersection(pm, firstmask, False)
-    
-#    cmreturn, notifymask = CellCorrespondance(pm, firstmask, False)
-    oldcells = np.unique(fm)
-    oldcells = oldcells[oldcells > 0]
-    
-    for cellval in oldcells:
-        if not(cellval in cmreturn):
-
-            notifymask[fm == cellval] = 2
-
-    
-    
-    for coordinates in smallclustr:
-
-        val = np.unique(fm[coordinates])
-        for cellval in oldcells:
-            if not(cellval in cmreturn):
-                if cellval in val:
-
-                    notifymask[coordinates] = 2
-
-        notifymask[coordinates] = 2
-
-            
-            
-
-
-    return fm, notifymask
-    
-    
-    
-    
-    
-    
-    
-    
-    
-    
-    
-    
-    
-    
-    
-    
-    
-    
-   
-        
diff --git a/unet/data.py b/unet/data.py
deleted file mode 100644
index 7d11ef30efe40dfb4ac88c72e763f545583d4e2d..0000000000000000000000000000000000000000
--- a/unet/data.py
+++ /dev/null
@@ -1,128 +0,0 @@
-"""
-Source of the code: https://github.com/zhixuhao/unet
-"""
-from __future__ import print_function
-from tensorflow.keras.preprocessing.image import ImageDataGenerator
-import numpy as np
-import os
-import glob
-import skimage.io as io
-import skimage.transform as trans
-
-Sky = [128,128,128]
-Building = [128,0,0]
-Pole = [192,192,128]
-Road = [128,64,128]
-Pavement = [60,40,222]
-Tree = [128,128,0]
-SignSymbol = [192,128,128]
-Fence = [64,64,128]
-Car = [64,0,128]
-Pedestrian = [64,64,0]
-Bicyclist = [0,128,192]
-Unlabelled = [0,0,0]
-
-
-COLOR_DICT = np.array([Sky, Building, Pole, Road, Pavement,
-                          Tree, SignSymbol, Fence, Car, Pedestrian, Bicyclist, Unlabelled])
-
-
-def adjustData(img,mask,flag_multi_class,num_class):
-    if(flag_multi_class):
-        img = img / 255
-        mask = mask[:,:,:,0] if(len(mask.shape) == 4) else mask[:,:,0]
-        new_mask = np.zeros(mask.shape + (num_class,))
-        for i in range(num_class):
-            #for one pixel in the image, find the class in mask and convert it into one-hot vector
-            #index = np.where(mask == i)
-            #index_mask = (index[0],index[1],index[2],np.zeros(len(index[0]),dtype = np.int64) + i) if (len(mask.shape) == 4) else (index[0],index[1],np.zeros(len(index[0]),dtype = np.int64) + i)
-            #new_mask[index_mask] = 1
-            new_mask[mask == i,i] = 1
-        new_mask = np.reshape(new_mask,(new_mask.shape[0],new_mask.shape[1]*new_mask.shape[2],new_mask.shape[3])) if flag_multi_class else np.reshape(new_mask,(new_mask.shape[0]*new_mask.shape[1],new_mask.shape[2]))
-        mask = new_mask
-    elif(np.max(img) > 1):
-        img = img / 255
-        mask = mask /255
-        mask[mask > 0.5] = 1
-        mask[mask <= 0.5] = 0
-    return (img,mask)
-
-
-
-def trainGenerator(batch_size,train_path,image_folder,mask_folder,aug_dict,image_color_mode = "grayscale",
-                    mask_color_mode = "grayscale",image_save_prefix  = "image",mask_save_prefix  = "mask",
-                    flag_multi_class = False,num_class = 2,save_to_dir = None,target_size = (256,256),seed = 1):
-    '''
-    can generate image and mask at the same time
-    use the same seed for image_datagen and mask_datagen to ensure the transformation for image and mask is the same
-    if you want to visualize the results of generator, set save_to_dir = "your path"
-    '''
-    image_datagen = ImageDataGenerator(**aug_dict)
-    mask_datagen = ImageDataGenerator(**aug_dict)
-    image_generator = image_datagen.flow_from_directory(
-        train_path,
-        classes = [image_folder],
-        class_mode = None,
-        color_mode = image_color_mode,
-        target_size = target_size,
-        batch_size = batch_size,
-        save_to_dir = save_to_dir,
-        save_prefix  = image_save_prefix,
-        seed = seed)
-    mask_generator = mask_datagen.flow_from_directory(
-        train_path,
-        classes = [mask_folder],
-        class_mode = None,
-        color_mode = mask_color_mode,
-        target_size = target_size,
-        batch_size = batch_size,
-        save_to_dir = save_to_dir,
-        save_prefix  = mask_save_prefix,
-        seed = seed)
-    train_generator = zip(image_generator, mask_generator)
-    for (img,mask) in train_generator:
-        img,mask = adjustData(img,mask,flag_multi_class,num_class)
-        yield (img,mask)
-
-
-def testGenerator(test_path,num_image = 30,target_size = (256,256),
-                  flag_multi_class = False,as_gray = True):
-    for i in range(num_image):
-        img = io.imread(os.path.join(test_path,"%d.png"%i),as_gray = as_gray)
-        img = img / 255
-        img = trans.resize(img,target_size)
-        img = np.reshape(img,img.shape+(1,)) if (not flag_multi_class) else img
-        img = np.reshape(img,(1,)+img.shape)
-        yield img
-
-
-def geneTrainNpy(image_path,mask_path,flag_multi_class = False,num_class = 2,image_prefix = "image",mask_prefix = "mask",image_as_gray = True,mask_as_gray = True):
-    image_name_arr = glob.glob(os.path.join(image_path,"%s*.png"%image_prefix))
-    image_arr = []
-    mask_arr = []
-    for index,item in enumerate(image_name_arr):
-        img = io.imread(item,as_gray = image_as_gray)
-        img = np.reshape(img,img.shape + (1,)) if image_as_gray else img
-        mask = io.imread(item.replace(image_path,mask_path).replace(image_prefix,mask_prefix),as_gray = mask_as_gray)
-        mask = np.reshape(mask,mask.shape + (1,)) if mask_as_gray else mask
-        img,mask = adjustData(img,mask,flag_multi_class,num_class)
-        image_arr.append(img)
-        mask_arr.append(mask)
-    image_arr = np.array(image_arr)
-    mask_arr = np.array(mask_arr)
-    return image_arr,mask_arr
-
-
-def labelVisualize(num_class,color_dict,img):
-    img = img[:,:,0] if len(img.shape) == 3 else img
-    img_out = np.zeros(img.shape + (3,))
-    for i in range(num_class):
-        img_out[img == i,:] = color_dict[i]
-    return img_out / 255
-
-
-
-def saveResult(save_path,npyfile,flag_multi_class = False,num_class = 2):
-    for i,item in enumerate(npyfile):
-        img = labelVisualize(num_class,COLOR_DICT,item) if flag_multi_class else item[:,:,0]
-        io.imsave(os.path.join(save_path,"%d_predict.png"%i),img)
diff --git a/unet/data_processing.py b/unet/data_processing.py
deleted file mode 100644
index c0836f118bd4523fb293601a5dba629dbc9e70c3..0000000000000000000000000000000000000000
--- a/unet/data_processing.py
+++ /dev/null
@@ -1,352 +0,0 @@
-import numpy as np
-
-import skimage
-from skimage import io
-from skimage.util import img_as_ubyte
-from skimage import morphology
-
-
-#############
-#           #
-#  READING  #
-#           #
-#############
-
-def read_im_tiff(path, num_frames=None):
-    """
-    Read a multiple page tiff images file, adapt the contrast and stock each
-    frame in a np.array
-    Param:
-        path: path to the tiff movie
-        num_frames: (integer) number of frames to read
-    Return:
-        images:
-    """
-
-    ims = io.imread(path)
-    images = []
-
-    if(num_frames == None):
-        num_frames = ims.shape[0]
-
-    for i in range(num_frames):
-        im = skimage.exposure.equalize_adapthist(ims[i])
-        images.append(np.array(im))
-
-    return images
-
-
-
-def read_lab_tiff(path, num_frames=None):
-    """
-    Read a multiple page tiff label file and stock each frame in a np.array
-    Param:
-        path: path to the tiff movie
-        num_frames: (integer) number of frames to read
-    Return:
-        images: (np.array) array containing the desired frames
-    """
-    
-    ims = io.imread(path)
-    images = []
-
-    if(num_frames == None):
-        
-        num_frames = ims.shape[0]
-
-    for i in range(num_frames):
-        
-        images.append(np.array(ims[i]))
-
-    return images
-
-
-
-################################
-#                              #
-#  GENERAL PROCESSING METHODS  #
-#                              #
-################################
-
-def pad(im, size):
-    """
-    Carry a mirror padding on an image to end up with a squared image dividable in multiple tiles of shape size x size pixels
-    Param:
-        im: (np.array) input image
-        size: (integer) size of the tiles
-    Return:
-        out: (np.array) output image reshaped to the good dimension
-    """
-
-    # add mirror part along x-axis
-    nx = im.shape[1]
-    ny = im.shape[0]
-
-    if(nx % size != 0.0):
-        restx = int(size - nx % size)
-        outx = np.pad(im, ((0,0), (0,restx)), 'symmetric')
-    else:
-        outx = im
-
-    if(ny % size != 0.0):
-        resty = int(size - ny % size)
-        out = np.pad(outx, ((0,resty), (0,0)), 'symmetric')
-    else:
-        out = outx
-
-    return out
-
-def threshold(im,th = None):
-    """
-    Binarize an image with a threshold given by the user, or if the threshold is None, calculate the better threshold with isodata
-    Param:
-        im: a numpy array image (numpy array)
-        th: the value of the threshold (feature to select threshold was asked by the lab)
-    Return:
-        bi: threshold given by the user (numpy array)
-    """
-    if th == None:
-        th = skimage.filters.threshold_isodata(im)
-    bi = im
-    bi[bi > th] = 255
-    bi[bi <= th] = 0
-    return bi
-
-
-def edge_detection(im):
-    """
-    Detect the edges on a label image
-    Param:
-        im: (np.array) input image
-    Return:
-        contour: (np.array) output image containing the edges of the input image
-    """
-
-    contour = np.zeros((im.shape[0], im.shape[1]))
-    vals = np.unique(im)
-
-    for i in range(0,vals.shape[0]):
-        a = np.zeros((im.shape[0], im.shape[1]))
-        a[im == vals[i]] = 255
-
-        dilated = skimage.morphology.dilation(a, selem=None, out=None)
-        eroded = skimage.morphology.erosion(a, selem=None, out=None)
-        edge = dilated - eroded
-
-        contour = contour + edge
-
-    contour[contour >= 255] = 255
-
-    return contour
-
-
-def split(im, size):
-    """
-    split an squared image with good dimensions (output of pad()) into tiles of shape size x size pixels
-    Param:
-        im: (np.array) input image
-        size: (integer) size of the tiles
-    Return:
-        ims: (list of np.array) output tiles
-    """
-
-    nx = im.shape[1]
-    ny = im.shape[0]
-
-    k_max = int(nx / size)    # number of 256 slices along x axis
-    l_max = int(ny / size)    # number of 256 slices along y axis
-
-    ims = []
-
-    for l in range(0, l_max):
-        for k in range(0, k_max):
-            frame = np.zeros((size,size))
-
-            lo_x = size * k
-            hi_x = size * k + size
-
-            lo_y = size * l
-            hi_y = size * l + size
-
-            frame = im[lo_y:hi_y, lo_x:hi_x]
-
-            # padding of the image to avoid border artefacts due to the convolutions
-            # out = np.pad(frame, ((10,10), (10,10)), 'symmetric')
-
-            ims.append(frame)
-    return ims
-
-
-
-def split_data(im, lab, ratio, seed):
-    """split the dataset based on the split ratio."""
-    # set seed
-    np.random.seed(seed)
-
-    index= np.arange(len(im))
-    np.random.shuffle(index)
-    num = int(ratio*len(index))
-    im = im[index]
-    lab = lab[index]
-
-    im_tr = im[0:num,:,:]
-    lab_tr = lab[0:num,:,:]
-
-    im_te = im[num:,:,:]
-    lab_te = lab[num:,:,:]
-
-    return im_tr, lab_tr, im_te, lab_te
-
-
-
-####################################
-#                                  #
-#  TRAIN AND TEST SETS GENERATORS  #
-#                                  #
-####################################
-
-def generate_test_set(im, out_im_path):
-  """
-  Generate the testing set from raw data
-  Param:
-    im: (np.array) input image
-    out_im_path: path to save the testing image
-
-  Return:
-    img_num: (integer) number of image produced
-    resized_shape: (tupple) shape of the padded image
-    original_shape: (tupple) shape of the original image
-  """
-  im = skimage.exposure.equalize_adapthist(im)
-
-  original_shape = im.shape
-
-  #resizing the input images
-  padded = pad(im, 236)
-  resized_shape = padded.shape
-
-  # splitting
-  splited = split(padded, 236)
-
-  # padding the tiles to get 256x256 tiles
-  padded_split = []
-  for tile in splited:
-    padded_split.append(np.pad(tile, ((10,10), (10,10)), 'symmetric'))
-
-  img_num = len(padded_split)
-
-  #saving the ouput images
-  for i in range(len(padded_split)):
-      name = str(i)
-      io.imsave( out_im_path + name + ".png", img_as_ubyte(padded_split[i]) )
-
-  return img_num, resized_shape, original_shape
-
-
-
-def generate_tr_val_set(im_col, lab_col, tr_im_path, tr_lab_path, val_im_path, val_lab_path):
-    """
-    Randomly generate training, validation and testing set from a given collection of images and labels with a 50/25/25 ratio
-    for detection of whole cells
-    Params:
-        im_col: (list of string) list of images (tiff movies) to include in the sets
-        lab_col: (list of string) list of labels (tiff movies) to include in the sets
-        tr_im_path: path to save training images
-        tr_lab_path: path to save training labels
-        val_im_path: path to save validation images
-        val_lab_path: path to save validation labels
-    Returns:
-        tr_len: (integer) number of samples in the training set
-        val_len: (integer) number of samples in the validation set
-    """
-    # reading raw data
-    ims = []
-    labs = []
-
-    for im in im_col:
-        print('im', im)
-        ims = ims + read_im_tiff(im)
-
-    for lab in lab_col:
-        print('label',lab)
-        labs = labs + read_lab_tiff(lab)
-
-    ims_out = []
-    labs_out = []
-
-
-    for i in range( len(ims) ):
-      # resizing images
-      im_out = pad(ims[i], 236)
-
-      # resizing and binarizing whole cell label
-      
-      threshold(labs[i],0)
-      lab_out = pad(labs[i], 236)
-
-      # splitting the images
-      split_im = split(im_out, 236)
-      split_lab = split(lab_out, 236)
-
-      # discarding images showing background only
-      # padding the tiles
-      split_im_out = []
-      split_lab_out = []
-      for j in range( len(split_lab) ):
-        if( np.sum(split_lab[j]) > 0.1 * 255 * 256 * 256 ):
-          split_im_out.append(np.pad(split_im[j], ((10,10), (10,10)), 'symmetric'))
-          split_lab_out.append(np.pad(split_lab[j], ((10,10), (10,10)), 'symmetric'))
-
-      ims_out = ims_out + split_im_out
-      labs_out = labs_out + split_lab_out
-
-    # splitting the list into multiple sets
-    im_tr, lab_tr, im_val, lab_val = split_data(np.array(ims_out),
-                                                         np.array(labs_out),
-                                                         0.75,
-                                                         1)
-
-    tr_len = im_tr.shape[0]
-    val_len = im_val.shape[0]
-
-    #saving the images
-    for i in range(tr_len):
-        io.imsave(tr_im_path + str(i) + ".png", img_as_ubyte(im_tr[i,:,:]))
-        io.imsave(tr_lab_path + str(i) + ".png", (lab_tr[i,:,:]))
-
-    for j in range(val_len):
-        io.imsave(val_im_path + str(j) + ".png", img_as_ubyte(im_val[j,:,:]))
-        io.imsave(val_lab_path + str(j) + ".png", (lab_val[j,:,:]))
-
-    return tr_len, val_len
-
-def reconstruct_result(tile_size, result, resized_shape, origin_shape):
-    """
-    Assemble a set of tiles to reconstruct the original, unsplitted image
-    Param:
-      tile_size: (integer) size of the tiles for the reconstruction
-      result: (np.array) result images of the network prediction
-      out_result_path: path to save the results
-      resized_shape: (tuple) size of the image padded for the splitting
-      origin_shape: (tuple) size or the raw images
-    Return:
-      out: (np.array) array containing the reconstructed images
-    """
-    nx, ny = int(resized_shape[1] / tile_size), int(resized_shape[0] / tile_size)
-    out = np.empty(resized_shape)
-
-    i = 0
-    for l in range(ny):
-      for k in range(nx):
-
-        lo_x = tile_size * k
-        hi_x = tile_size * k + tile_size
-
-        lo_y = tile_size * l
-        hi_y = tile_size * l + tile_size
-
-        out[lo_y:hi_y, lo_x:hi_x] = result[i,:,:]
-
-        i = i+1
-
-    return out[ 0:origin_shape[0], 0:origin_shape[1] ]
diff --git a/unet/neural_network.py b/unet/neural_network.py
index 05dc9c2f6c03f8f72eafdaf79f643eb805766daa..b0c3906e4afeafb93242152aba824ca280e68884 100644
--- a/unet/neural_network.py
+++ b/unet/neural_network.py
@@ -6,9 +6,10 @@ Created on Sat Dec 21 18:54:10 2019
 """
 import os
 from model import unet
-import data
 import numpy as np
 import skimage
+from skimage import io
+import skimage.transform as trans
 
 
 def create_directory_if_not_exists(path):
@@ -57,9 +58,9 @@ def prediction(im):
     create_directory_if_not_exists(path_test)
 
     # WHOLE CELL PREDICTION
-    testGene = data.testGenerator(path_test,
-                                  1,
-                                  target_size = (2048,2048))
+    testGene = testGenerator(path_test,
+                             1,
+                             target_size = (2048,2048))
 
     model = unet(pretrained_weights = None,
                  input_size = (2048,2048,1))
@@ -79,3 +80,14 @@
 
     return res
 
+
+def testGenerator(test_path,num_image = 30,target_size = (256,256),
+                  flag_multi_class = False,as_gray = True):
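+    # moved here from the removed unet/data.py; yields one normalized, resized test image per PNG file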
+    for i in range(num_image):
+        img = io.imread(os.path.join(test_path,"%d.png"%i),as_gray = as_gray)
+        img = img / 255
+        img = trans.resize(img,target_size)
+        img = np.reshape(img,img.shape+(1,)) if (not flag_multi_class) else img
+        img = np.reshape(img,(1,)+img.shape)
+        yield img
diff --git a/unet/quality_measures.py b/unet/quality_measures.py
deleted file mode 100644
index b5b7d7cce2fefbda2bd87df4b008f5dac37b8285..0000000000000000000000000000000000000000
--- a/unet/quality_measures.py
+++ /dev/null
@@ -1,166 +0,0 @@
-"""
-Source of the code: https://github.com/mattminder/yeastSegHelpers
-"""
-import numpy as np
-
-def cell_correspondance_frame(truth, pred):
-    """
-    Finds for one frame the correspondance between the true and the predicted cells.
-    Returns dictionary from predicted cell to true cell.
-    """
-    keys = np.unique(pred)
-    vals = np.empty(len(keys), dtype=int)
-    for i, cell in enumerate(keys):
-        v, c = np.unique(truth[pred==cell], return_counts=True)
-        vals[i] = v[np.argmax(c)]
-    return dict(zip(keys, vals))
-
-def split_dict(cc):
-    """
-    Returns dictionary of cells in predicted image, to whether they were split unnecessarily
-    (includes cells that were assigned to background)
-    """
-    truth = list(cc.values())
-    pred = list(cc.keys())
-
-    v, c = np.unique(list(truth), return_counts=True)
-    split_truth = v[c>1]
-    return dict(zip(pred, np.isin(truth, split_truth)))
-
-def fused_dict(cc, truth, pred):
-    """
-    Returns dictionary of cells in predicted image to whether they were fused unnecessarily.
-    """
-    rev_cc = cell_correspondance_frame(pred, truth)
-    rev_split_dict = split_dict(rev_cc)
-    out = dict()
-    for p in cc:
-        out[p] = rev_split_dict[cc[p]]
-    return out
-
-def nb_splits(cc):
-    """
-    Counts the number of unnecessary splits in the predicted image (ignores background).
-    """
-    v, c = np.unique(list(cc.values()), return_counts=True)
-    return (c[v>0]-1).sum()
-
-def nb_fusions(cc, truth):
-    """
-    Counts number of cells that were fused.
-    """
-    true_cells = set(np.unique(truth))
-    pred_cells = set(list(cc.values()))
-    return len(true_cells - pred_cells)
-
-def nb_artefacts(cc):
-    """
-    Number of predicted cells that are really background.
-    """
-    valarr = np.array(list(cc.values()))
-    keyarr = np.array(list(cc.keys()))
-    return (valarr[keyarr>0]==0).sum()
-
-def nb_false_negatives(truth, pred):
-    """
-    Number of cells of mask that were detected as background in the prediction.
-    """
-    rev_cc = cell_correspondance_frame(pred, truth)
-    return nb_artefacts(rev_cc)
-
-def over_undershoot(truth, pred, cc, look_at):
-    """
-    Calculates average number of pixels that were overshot by cell. Uses look_at, which
-    is a dict from pred cells to whether they should be considered (allows to exclude wrongly
-    fused or split cells)
-    """
-    overshoot = 0
-    undershoot = 0
-    cellcount = 0
-
-    for p in cc:
-        if not look_at[p]:
-            continue
-        t = cc[p]
-        if t==0 or p==0: # Disregard if truth or prediction is background
-            continue
-        cellcount += 1
-        overshoot += (pred[truth!=t]==p).sum()
-        undershoot += (truth[pred!=p]==t).sum()
-    return overshoot/cellcount, undershoot/cellcount
-
-def average_pred_area(pred, cc, look_at):
-    """
-    Calculates average predicted area of all considered cells
-    """
-    area = 0
-    cellcount = 0
-    for p in cc:
-        if not look_at[p]:
-            continue
-        if p==0:
-            continue
-        cellcount += 1
-        area += (pred==p).sum()
-    return area/cellcount
-
-def average_true_area(truth, cc, look_at):
-    """
-    Calculates average true area of all considered cells
-    """
-    area = 0
-    cellcount = 0
-    for p in cc:
-        if not look_at[p]:
-            continue
-        t = cc[p]
-        if t==0:
-            continue
-        cellcount += 1
-        area += (truth==t).sum()
-    return area/cellcount
-
-def n_considered_cells(cc, look_at):
-    """
-    Calculates the number of considered cells
-    """
-    count = 0
-    for p in cc:
-        if not look_at[p]:
-            continue
-        if p==0 or cc[p]==0:
-            continue
-        count += 1
-    return count
-
-def quality_measures(truth, pred):
-    """
-    Tests quality of prediction with four statistics:
-    Number Fusions:     How many times are multiple true cells predicted as a single cell?
-    Number Splits:      How many times is a single true cell split into multiple predicted cell?
-    Av. Overshoot:      How many pixels are wrongly predicted to belong to the cell on average
-    Av. Undershoot:     How many pixels are wrongly predicted to not belong to the cell on average
-    Av. true area:      Average area of cells of truth that are neither split nor fused
-    Av. pred area:      Average area of predicted cells that are neither split nor fused
-    Nb considered cells:Number of cells that are neither split nor fused
-    """
-    cc = cell_correspondance_frame(truth, pred)
-    res = dict()
-
-    # Get indices of cells that are not useable for counting under- and overshooting
-    is_split = split_dict(cc)
-    is_fused = fused_dict(cc, truth, pred)
-    look_at = dict()
-    for key in is_split:
-        look_at[key] = not (is_split[key] or is_fused[key])
-
-    # Result Calculation
-    res["Number Fusions"] = nb_fusions(cc, truth)
-    res["Number Splits"] = nb_splits(cc)
-    res["Nb False Positives"] = nb_artefacts(cc)
-    res["Nb False Negatives"] = nb_false_negatives(truth, pred)
-    res["Average Overshoot"], res["Average Undershoot"] = over_undershoot(truth, pred, cc, look_at)
-    res["Av. True Area"] = average_true_area(truth, cc, look_at)
-    res["Av. Pred Area"] = average_pred_area(pred, cc, look_at)
-    res["Nb Considered Cells"] = n_considered_cells(cc, look_at)
-    return res