diff --git a/GUI_main.py b/GUI_main.py new file mode 100644 index 0000000000000000000000000000000000000000..6423ddc73b4e84a50e76eb48ec9e33eeb4dafaea --- /dev/null +++ b/GUI_main.py @@ -0,0 +1,2614 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +This script is the main script used to produce a GUI which should help for +cell segmentation. This script can only read .nd2 files containing the +images of cells, especially it displays for each recorded positions (field +of view) the pictures in the time axis. + +The script opens first a window which allows you to load an nd2 file and +to load or create an hdf file. The hdf file contains all the masks, so if +it is the first the user segments an nd2 file, a new one should be created. +And it can be then loaded for later use. Along with new hdf file, created +by the name entered by the user (say filename), it creates three other hdf +files (filename_predicted.h5, filename_thresholded.h5 and +filename_segmented.h5) these contain all the steps of the NN to get to the +segmented picture. + +After the first window is finished a second one opens, where at each time +index, three pictures +are displayed the t-1 picture, the t picture (current frame which can be +edited) and the t+1 picture. Using the arrows one can navigate through time. +On top of the picture, there is always a mask which is displayed, if no cells +are present in the mask then the mask is blank and the user does not see it. +If one wants to hand anmotate the pictures, one can just start to draw on the +picture using the different functions (New Cell, Add Region, Brush, Eraser, +Save Mask,...) and the informations will be saved in the mask overlayed on +top of the pictures. + +If one wants to segment using a neural network, one can press the +corresponding button (Launch CNN) and select the time range and +the field of views on which the neural network is applied. + +Once the neural network has finished predicting, there are still no visible +masks, but on the field of views and time indices where the NN has been +applied, the threshold and segment buttons are enabled. By checking these +two buttons one can either display the thresholded image of the prediction or +display the segmentation of the thresholded prediction. + +At this stage, one can correct the segmentation of the prediction using +the functions (New Cell, Add Region, etc..) by selecting the Segment +checkbox and then save them using the Save Seg button. +If the user is happy with the segmentation, the Cell Correspondance button +can be clicked. Until then, the cells get random numbers attributed by +the segmentation algorithm. In order to keep track of the cell through time, +the same cells should have the same number between two different time pictures. +This can be (with some errors) achieved by the Cell Correspondance button, +which tries to attribute the same number to corresponding cells in time. +After that, the final mask is saved and it is always visible when you go on +the corresponding picture. This mask can also be corrected using the +usual buttons (because the Cell Correspondance makes also mistakes). + +""" + +import sys +#append all the paths where the modules are stored. Such that this script +#looks into all of these folders when importing modules. +sys.path.append("./unet") +sys.path.append("./disk") +sys.path.append("./icons") +sys.path.append("./init") +sys.path.append("./misc") +import time +import os +import numpy as np + + +# Import everything for the Graphical User Interface from the PyQt5 library. +from PyQt5.QtWidgets import QApplication, QMainWindow, QMenu, QVBoxLayout, QSizePolicy, QMessageBox, QWidget, QPushButton, QShortcut, QComboBox, QCheckBox, QLineEdit, QMenu, QAction, QStatusBar +from PyQt5 import QtGui +from PyQt5.QtCore import pyqtSignal, QObject, Qt + +#Import from matplotlib to use it to display the pictures and masks. +from matplotlib.backends.qt_compat import QtCore, QtWidgets, is_pyqt5 +from matplotlib.backends.backend_qt5agg import (FigureCanvasQTAgg as FigureCanvas, NavigationToolbar2QT as NavigationToolbar) +from matplotlib.figure import Figure +import matplotlib.pyplot as plt +#import the colormaps function to create a customed colormap scale with 10 +#colors only +from matplotlib import cm +from matplotlib.colors import ListedColormap, LinearSegmentedColormap +#import Path functions to handle the regions drawn by the user. ("add region +#and new cell") +from matplotlib.path import Path + + +#Import all the other python files +#this file handles the interaction with the disk, so loading/saving images +#and masks and it also runs the neural network. +import InteractionDisk_temp as nd + + +#this file contains a dialog window that takes two integers as entry to swap +#two cell values +import ExchangeCellValues as ecv +#this file contains a dialog window which is opened before the main program +#and allows to load the nd2 and hdf files by browsing through the computer. +import DialogFileBrowser as dfb +#this file contains a window that opens to change the value of one cell. It +#is opened as soon as the user presses with the left click on a specific cell. +import ChangeOneCellValue as cocv +#this file contains a dialog window to browse for the excel file where +#all the extracted information on the fluoerscence is written. Or to create a +#new excel file by typing a name in the text box. It is thought to have one +#excel file per field of view. +import DialogDataBrowser as ddb +#this file contains a dialog window where a time range and the field of views +#can be selected to then launch a prediction of the neural network on +#a specific range of pictures. + +import LaunchBatchPrediction as lbp + +#this file initializes all the buttons present in the gui, sets the shortcuts +#to these buttons and also connect the buttons to the function that are +#triggered when the buttons are pressed. +import InitButtons + +#this file contains the layout of the main window so it justs puts the buttons +#and the pictures at the desired position in the main window. +import InitLayout + + +import random + +#import everything needed to write and read excel files. +from openpyxl import load_workbook +from openpyxl import Workbook + + + + +class NavigationToolbar(NavigationToolbar): + """This is the standard matplotlib toolbar but only the buttons + that are of interest for this gui are loaded. These buttons allow + to zoom into the pictures/masks and to navigate in the zoomed picture. + A Home button can be used to set the view back to the original view. + """ + toolitems = [t for t in NavigationToolbar.toolitems if + t[0] in ('Home', 'Pan', 'Zoom','Back', 'Forward')] + +class App(QMainWindow): + """This class creates the main window. + + """ + + def __init__(self, nd2pathstr, hdfpathstr, newhdfstr): + super().__init__() +# initializes the window + + +# id is an integer that gives the id of the connection between the mouseclick method +# and the activation of the button. + +# all these ids are integers which are used to set a connection between +# the button and the function that this button calls. +# There are three of them because it happens that one can trigger three +# different functions with one button. + self.id = 0 + self.id2 = 0 + self.id3 = 0 + + + +# it calls an object of the class Load Image from the InteractionDisk +# file which is used to load images and masks from the nd2 file. To +# initialize this object it needs the path of the nd2 file, of an +# existing hdf file and the name of a new hdf file. If the user has no +# hdf file yet the hdfpathstr will be empty and vice versa if the user +# selects an already existing hdf file. +# It takes all the strings given by the first window +# (called before the main window opens) from the DialogFileBrowser.py +# file. (see at the end of this code) + + self.reader = nd.Reader(hdfpathstr, newhdfstr, nd2pathstr) + + +# these variables are used to create/read/load the excel file used +# to write the fluorescence values extracted. For each field of view, +# the user will be asked each time to create a new xls file for the +# field of view or to load an existing field of view (this is the role +# of the boolean variable) + self.xlsfilename = '' + self.nd2path = nd2pathstr + self.FlagFluoExtraction = False + + + + +# Set the indices for the time axis and the field of view index. These +# indices represent everywhere the current picture (the one that can be +# edited, i.e. the time t frame) + self.Tindex = 0 + self.FOVindex = 0 + +# loading the first images of the cells from the nd2 file + self.currentframe = self.reader.LoadOneImage(self.Tindex,self.FOVindex) + self.nextframe = self.reader.LoadOneImage(self.Tindex+1, self.FOVindex) + self.previousframe = np.zeros([self.reader.sizey, self.reader.sizex]) + + + +# loading the first masks from the hdf5 file + self.mask_curr = self.reader.LoadMask(self.Tindex, self.FOVindex) + self.mask_previous = np.zeros([self.reader.sizey, self.reader.sizex]) + self.mask_next = self.reader.LoadMask(self.Tindex+1, self.FOVindex) + + +# creates a list of all the buttons, which will then be used in order +# to disable all the other buttons at once when one button/function +# is pressed/used in the gui. + self.buttonlist = [] + + +# setting buttons as attributes +# the shortcuts for the buttons, the functions to which they are +# connected to,... are all set up in the ButtonInit file which is called +# in the self.initUI() method below. + + self.button_newcell = QPushButton("New cell") + self.buttonlist.append(self.button_newcell) + + self.button_add_region = QPushButton("Add region") + self.buttonlist.append(self.button_add_region) + + self.button_savemask = QPushButton("Save Mask") + self.buttonlist.append(self.button_savemask) + + self.button_drawmouse = QPushButton('Brush') + self.buttonlist.append(self.button_drawmouse) + + self.button_eraser = QPushButton('Eraser') + self.buttonlist.append(self.button_eraser) + + self.button_exval = QPushButton('Exchange Cell Values') + self.buttonlist.append(self.button_exval) + + self.button_showval = QCheckBox('Show Cell Values') + self.buttonlist.append(self.button_showval) + + self.button_hidemask = QCheckBox('Hide Mask') + self.buttonlist.append(self.button_hidemask) + + self.button_nextframe = QPushButton("Next Time Frame") + self.buttonlist.append(self.button_nextframe) + + self.button_previousframe = QPushButton("Previous Time Frame") + self.buttonlist.append(self.button_previousframe) + + self.button_cnn = QPushButton('Launch CNN') + self.buttonlist.append(self.button_cnn) + + self.button_threshold = QCheckBox('Threshold prediction') + self.buttonlist.append(self.button_threshold) + + self.button_segment = QCheckBox('Segment') + self.buttonlist.append(self.button_segment) + + self.button_cellcorespondance = QPushButton('Cell Correspondance') + self.buttonlist.append(self.button_cellcorespondance) + + self.button_changecellvalue = QPushButton('Change cell value') + self.buttonlist.append(self.button_changecellvalue) + + self.button_extractfluorescence = QPushButton('Extract Fluorescence') + self.buttonlist.append(self.button_extractfluorescence) + + + self.initUI() + + + + + def initUI(self): + """Initializing the widgets contained in the window. + Especially, it creates the widget to plot the + pictures/masks by creating an object of the PlotCanvas class self.m. + Every interaction with the masks or the pictures (loading new + frames/editing the frames/masks) occurs through this class. + + This method initializes all the buttons with the InitButtons file. + It connects the buttons to the functions that they should trigger, + it sets the shortcuts to the buttons, a tool tip, + eventually a message on the status bar when the user hovers + over the button, etc.. + + This function also sets all the layout in the InitLayout file. It + takes and places the widgets (buttons, canvas, toolbar). + + The function initializes a Menu Bar to have a menu which can be + improved later on. + It sets a toolbar of the matplotlib library and hides it. But it allows + to connect to the functions of this toolbar through "homemade" + QPushButtons instead of the ones provided by matplotlib. + Finally, it sets a StatusBar which displays some text to describe + the use of some buttons, or to show that the program is working on + something (running the neural network, loading frames, etc...) + + After all this has been initialized, the program is ready to be used. + + """ + self._main = QtWidgets.QWidget() + self.setCentralWidget(self._main) + +# Here our canvas is created where using matplotlib, +# one can plot data to display the pictures and masks. + self.m = PlotCanvas(self) + +# Initialize all the buttons that are needed and the functions that are +# connected when the buttons are triggered. + InitButtons.Init(self) + InitLayout.Init(self) + +# MENU, TOOLBAR AND STATUS BAR +# creates a menu just in case, some other functions can be added later +# in this menu. + menubar = self.menuBar() + self.fileMenu = menubar.addMenu('File') + self.saveactionmenu = QAction('Save') + self.fileMenu.addAction(self.saveactionmenu) + self.saveactionmenu.triggered.connect(self.ButtonSaveMask) + + +# hide the toolbar and instead of the original buttons of matplotlib, +# QPushbuttons are used and are connected to the functions of the toolbar +# it is than easier to interact with these buttons (for example to +# to disable them and so on..) + self.Nvgtlbar = NavigationToolbar(self.m, self) + self.addToolBar(self.Nvgtlbar) + self.Nvgtlbar.hide() + +# creates a status bar, which displays (or should display) some text +# whenever a function is used. + self.statusBar = QStatusBar() + self.setStatusBar(self.statusBar) + + self.show() + + + def mousePressEvent(self, QMouseEvent): + """this function is implemented just to have the QLineButtons of the + change time index button, setthreshold button and the setsegmentation + button out of focus when the user clicks somewhere + on the gui. (to unfocus the buttons) + """ + self.button_timeindex.clearFocus() + if self.button_SetThreshold.isEnabled(): + self.button_SetThreshold.clearFocus() + + if self.button_SetSegmentation.isEnabled(): + self.button_SetSegmentation.clearFocus() + + + +# CONNECT the functions of the toolbar to our custom QPushbuttons. + def ZoomTlbar(self): + """The button_zoom is connected to the zoom function of the toolbar + already present in the matplotlib library. + + Depending on the buttons that are active or checked, when the zoom + function is used, it does not disable all the buttons. + + If the segment and threshold button are not checked or used + when the zoom button is clicked, it disables all the button + using self.Disable which disables everything except the button passed + in argument (in this case button_zoom). + + + If the zoom button is used while the segment button is checked, + it disables all the buttons (1st elif) except the segment button + but once it is finished (so the zoom button becomes unchecked) + then it enables only the editing buttons (as long as the segment + button is still checked) such as New Cell, Add Region, Eraser, + Brush,etc.. and the other toolbar buttons (3rd elif) + + + If the zoom button is clicked while the threshold button is checked, + it disables all the button except the threshold button (2nd elif). + Once the zoom button is unchecked, it enables the toolbar buttons + (4th elif) + In any other case, it just enables all the buttons again. + + + """ + self.Nvgtlbar.zoom() + + if self.button_zoom.isChecked() and not(self.button_segment.isChecked() or self.button_threshold.isChecked()): + self.Disable(self.button_zoom) + + elif self.button_zoom.isChecked() and self.button_segment.isChecked(): + self.Disable(self.button_zoom) + self.button_segment.setEnabled(True) + + elif self.button_zoom.isChecked() and self.button_threshold.isChecked(): + self.Disable(self.button_zoom) + self.button_threshold.setEnabled(True) + + elif self.button_zoom.isChecked() == False and self.button_segment.isChecked(): + + self.button_pan.setEnabled(True) + self.button_home.setEnabled(True) + self.button_back.setEnabled(True) + self.button_forward.setEnabled(True) + + self.EnableCorrectionsButtons() + + elif self.button_zoom.isChecked() == False and self.button_threshold.isChecked(): + + self.button_pan.setEnabled(True) + self.button_home.setEnabled(True) + self.button_back.setEnabled(True) + self.button_forward.setEnabled(True) + + else: + self.Enable(self.button_zoom) + + def HomeTlbar(self): +# connects the home button to the home function of the matplotlib +# toolbar. It sets the view to the original view (no zoom) + self.Nvgtlbar.home() + + + def BackTlbar(self): +# It calls the back function of the matplotlib toolbar which sets the +# view to the previous one (if the user does several zooms/pans, +# this button allows to go back in the "history of views") + self.Nvgtlbar.back() + + + + def ForwardTlbar(self): +# It calls the forward function of the matplotlib toolbar which sets the +# view to the next one (if the user does several zooms/pans, +# this button allows to go forward in the "history of views" + + self.Nvgtlbar.forward() + + + + def PanTlbar(self): + """The button_pan is connected to the pan function of the toolbar + already present in the matplotlib library. + + Depending on the buttons that are active or checked, when the pan + function is used, it does not disable all the buttons. + + If the segment and threshold button are not checked or used + when the pan button is clicked, it disables all the button + using self.Disable which disables everything except the button passed + in argument (in this case button_pan). + + + If the pan button is used while the segment button is checked, + it disables all the buttons (1st elif) except the segment button + but once it is finished (so the zoom button becomes unchecked) + then it enables only the editing buttons (as long as the segment + button is still checked) such as New Cell, Add Region, Eraser, + Brush,etc.. and the other toolbar buttons (3rd elif) + + + If the pan button is clicked while the threshold button is checked, + it disables all the button except the threshold button (2nd elif). + Once the pan button is unchecked, it enables the toolbar buttons + (4th elif) + In any other case, it just enables all the buttons again. + """ + + self.Nvgtlbar.pan() + + if self.button_pan.isChecked() and not(self.button_segment.isChecked() or self.button_threshold.isChecked()): + self.Disable(self.button_pan) + + elif self.button_pan.isChecked() and self.button_segment.isChecked(): + self.Disable(self.button_pan) + self.button_segment.setEnabled(True) + + elif self.button_pan.isChecked() and self.button_threshold.isChecked(): + self.Disable(self.button_pan) + self.button_threshold.setEnabled(True) + + elif not(self.button_pan.isChecked()) and self.button_segment.isChecked(): + + self.button_zoom.setEnabled(True) + self.button_home.setEnabled(True) + self.button_back.setEnabled(True) + self.button_forward.setEnabled(True) + + self.EnableCorrectionsButtons() + + elif not(self.button_pan.isChecked()) and self.button_threshold.isChecked(): + self.button_zoom.setEnabled(True) + self.button_home.setEnabled(True) + self.button_back.setEnabled(True) + self.button_forward.setEnabled(True) + + + else: + self.Enable(self.button_pan) + + + def ButtonFluo(self): + """This function is called everytime the Extract Fluorescence button is + clicked (self.button_extractfluorescence). + + self.FlagFluoExtraction is boolean which is True when the path to the + excel file has already been loaded into self.xlsfilename. + This pathname changes for each field of view as it is thought to have + one xls file per field of view. + So at the beginning and each time the user changes field of view, + self.FlagFluoExtraction is set to False. + + When it is set to False, this function calls a dialog window where + the user is asked to load an already existing xls file for the current + field of view or to give a name to create a new xls file for + the current field of view. (self.Dialogxls) + + If it set to true, it means that self.xlsfilename contains the path + to the xls file for the current field of view and it is directly given + to the function that writes the fluorescence into the xls file. + (self.ExtractFluo) + """ + + if self.FlagFluoExtraction: + self.ExtractFluo(self.xlsfilename) + else: + self.DialogXls() + + + def DialogXls(self): + """This function creates a dialog window which gives two options to the + user either to load an existing xls file or to give a new name in order + to create a new xls file. + """ +# creates the window + dwind = ddb.FileBrowser() + +# this test is True if the user presses ok in the dialog window, if the +# user presses cancels it returns and nothing happens. + + if dwind.exec_(): + +# read the entry given by the file browser + xlsname = dwind.xlsname + +# reads the entry given by the new filename text field. + newxlsname = dwind.newxlsentry.text() + +# if the string containing the filepath to an existing xls file +# is not empty then it calls directly the function to write the +# data into this existing xls file and sets self.xlsfilename + if xlsname: + + self.xlsfilename = xlsname + self.ExtractFluo(xlsname) + +# if xlsname is empty then it creates a new pathfilename and puts +# the new created xls file into the folder where nd2 is located. +# the string containing the nd2 namepath is split + else: + xlsname = '' + templist = self.nd2path.split('/') + + for k in range(0, len(templist)-1): + + xlsname = xlsname + templist[k] + '/' +# this is the new path/filename + xlsname = xlsname + newxlsname + '.xlsx' + self.xlsfilename = xlsname + +# here as a new name has been given, it means that a new xls file +# should be created, this is done with CreateXls + self.CreateXls(xlsname) +# once there is an existing xls file, it writes in this file +# using self.ExtractFluo. + self.ExtractFluo(xlsname) +# this flag is set to true, for the current field of view each +# time extract fluorescence is clicked it writes in the file located +# at self.xlsfilename. + self.FlagFluoExtraction = True + else: + + return + + + def CreateXls(self, xlsfilename): + """In case there is no xls file existing, here a new one is created + and prepared. For each channel a new sheet is created. + In the first row for each sheet, the time indices are written t = 0, + t = 1, etc... but only every third column. Because in the row below, + three values are extracted 'Total intensity', 'Area' and 'Variance' + at each time index. So three columns for each time index are needed, + for the three data points. + The first column is left empty (starting from third row) because + the cell numbers will be written in there. + """ + +# creates a new xls file using xlwt library. + book = Workbook() + nbrchannels = self.reader.sizec + + for i in range(0,nbrchannels): + sheetname = self.reader.channel_names[i] +# creates a sheet with the name of the corresponding channel. + if i == 0: + sheet = book.active + sheet.title = sheetname + else: + sheet = book.create_sheet(sheetname) + sheet.cell(1,1, 'Cell Number / Time axis') + sheet.cell(2,1, 'labels') + timeaxissize = self.reader.sizet +# start writing the time index at column 1, column 0 is reserved for +# cell numbers. + timecolindex = 2 + + for t in range(1,timeaxissize+1): +# in row 0 the time index is written + sheet.cell(1,timecolindex).value = 't = {}'.format(t-1) +# in row 1, the label of the three data points are written + sheet.cell(2,timecolindex).value = 'Total Intensity' + sheet.cell(2,timecolindex+1).value = 'Total Area' + sheet.cell(2,timecolindex+2).value = 'Mean Intensity' + sheet.cell(2,timecolindex+3).value = 'Variance' +# updates the index, where the next time index should be written + timecolindex = timecolindex + 4 + +# saves the xls file. + book.save(xlsfilename) + + + + + def ExtractFluo(self, xlsfilename): + """This is the function that takes as argument the filepath to the xls + file and writes in the file. + It iterates over the different channels (or the sheets of the file, + each channel has one sheet.), and reads the image corresponding + to the time, field of view and channel index. It reads the already + existing file and makes a copy in which the data will be written in it. + + The first step of calculating the data is to iterate through each + cell/segment of the mask (so each cell is a submatrix of one value + in the matrix of the mask). + For each of these value /cell, the area is extracted as being + the number of pixels corresponding to this cell/value. + (it is known from the microscope settings how to convert + the pixel in area). + The total intensity is just the value of the pixel and it is added over + all the pixels corresonding to the cell/value. + The mean is then calculated as being the total intensity divided by + the number of pixels (which here is equal to the area also). + With the mean it is then possible to calculate the variance of the + signal for one cell/value. + + Then, it is checked if the value of the cell (cell number) already + exists in the first column, if it already exists it continues to + find the column corresponding to the time index where the values + should be written. It sets the flag to True such that it does not + write the cell as new one and adds it at the end of the column + + If the value is not found in the cell number column (new cell or + first time writing in the file), the flag is False, thus it adds the + cell number at the end of the column. + It then saves the xls file. + + """ + +# disables all the buttons except the one passed in argument. + self.Disable(self.button_extractfluorescence) +# shows a message on the status bar to show that the program is working + self.statusBar.showMessage('Extracting the fluorescence...') + +# opens the file to read it. + book = load_workbook(self.xlsfilename) + +# makes a copy of the reading file to write in it the new values. +# wb = xlscopy(readbook) # a writable copy (can't read values out of this, only write to it) + +# iterate over all the channels, so over all the sheets in the file + for channel in range(0, self.reader.sizec): +# loads the picture corresponding to the channel, time index and fov + image = self.reader.LoadImageChannel(self.Tindex, self.FOVindex, channel) + +# loads the sheet to read out corresponding to the current channel + sheet = book.worksheets[channel] +# sheet = readbook.sheet_by_index(channel) + +# this line is here to prevent some errors of streaming into +# the file due to read file which is open (I am not sure about this +# but it is a more or less working solution found on stackoverflow) +# os.remove(xlsfilename) + +# load the sheet corresponding to the current channel to write in it +# writingsheet = wb.get_sheet(channel) + +# this index contains the value of the maximum number of rows in the +# file, it is used to append at the end the cell number column a new +# cell/value, and it is updated each time a new cell is added. + tempidx = sheet.max_row + +# np.unique(array) returns an array which contains all the value +# that appear in self.m.plotmask, so it returns every cell value +# including the background (value 0) present in self.m.plotmask + for val in np.unique(self.m.plotmask): + +# exclude the background, as it is not a cell + if val != 0: + +# this (self.m.plotmask==val).sum() adds one at every pixel +# where self.m.plotmask has the value val + area = (self.m.plotmask == val).sum() +# it sums the value of the pixel in image at the coordinates +# where self.m.plotmask equals val. + tot_intensity = image[self.m.plotmask == val].sum() + +# calculate the mean to use it in the calc. of the variance. + mean = tot_intensity/area + +# create a copy of plotmask because I had weird experiences +# with np.where where it modified sometimes the given array +# (not sure) + temparr = self.m.plotmask.copy() + +# extract the coordinates of the mask matrix where it equals +# the current cell value val + coord = np.where(temparr == val) + +# variable to save the variance. + var = 0 + +# we loop over the coord of the pixel where mask == val + for i in range(0,len(coord[0])): +# extract the intensity at the coordinate + val_intensity = image[coord[0][i], coord[1][i]] +# substract the value of the intensity with the mean, +# and square it. It is then add to the var. + var = var + (val_intensity-mean)*(val_intensity-mean) + +# var is then divided by the number of the pixel +# corresponding to this value (also equal to the area) +# to get the variance. + var = var/area + +# if flag is false it means that the cell number +# corresponding to val is not present in the xls file, first +# column. + flag = False + + +# iterate over all the rows + for row in range(sheet.max_row+1): + +# test if in the first column 0, the number of the cell +# is already present +# if sheet.cell_value(row,0) == str(val): + if sheet.cell(row = row+1, column = 1).value == str(val): + +# if is present, the column corresponding to the +# current time index is by iterating over the cols. + for col in range(sheet.max_column+1): +# test if it is the right column +# if sheet.cell_value(0, col) == 't = {}'.format(self.Tindex): + if sheet.cell(row = 1, column = col+1).value == 't = {}'.format(self.Tindex): +# write in the xls file at the row, col coord + sheet.cell(row+1, col+1, str(tot_intensity)) + sheet.cell(row+1, col+2, str(area)) + sheet.cell(row+1,col+3, str(mean)) + sheet.cell(row+1, col+4, str(var)) + book.save(xlsfilename) + + +# the flag is set to True so that it does +# not execute the code where the cell is +# added in the xls file in a new row. + flag = True + if not flag: +# this lines are executed if a new cell is detected or if +# if it is the first time to write in the file. + for col in range(sheet.max_column+1): + if sheet.cell(row = 1, column = col+1).value == 't = {}'.format(self.Tindex): +# it write the cell value/cell number in the +# column + sheet.cell(tempidx+1,1, str(val)) + +# writes the data extracted before + sheet.cell(tempidx+1,col+1,str(tot_intensity)) + sheet.cell(tempidx+1, col+2, str(area)) + sheet.cell(tempidx+1, col+3, str(mean)) + sheet.cell(tempidx+1, col+4, str(var)) +# it updates the number of rows as a new cell +# has been added, so there is one more row. + tempidx = tempidx + 1 +# save in the file + book.save(xlsfilename) + + +# Enable again all the buttons + self.Enable(self.button_extractfluorescence) + +# clear the message shown in the status bar + self.statusBar.clearMessage() + + + + def LaunchBatchPrediction(self): + """This function is called whenever the button Launch CNN is pressed. + It allows to run the neural network over a time range and selected + field of views. + + It creates a dialog window with two entries, that define the time range + and a list where the user can select the desired fields of view. + + Once it reads all the value, it calls the neural network function + inside of self.PredThreshSeg and it does the prediction of the neural + network, thresholds this prediction and then segments it. + """ +# creates a dialog window from the LaunchBatchPrediction.py file + dlg = lbp.CustomDialog(self) + +# this if tests if the user pressed 'ok' in the dialog window + if dlg.exec_(): + +# it tests if the user has entered some values +# if not it ignores and returns. + if dlg.entry1.text()!= '' and dlg.entry2.text() != '': + +# reads out the entry given by the user and converts the index +# to integers + time_value1 = int(dlg.entry1.text()) + time_value2 = int(dlg.entry2.text()) + + +# it tests if the first value is smaller or equal such that +# time_value1 is the lower range of the time range +# and time_value2 the upper boundary of the range. + if time_value1 <= time_value2 : + +# displays that the neural network is running + self.statusBar.showMessage('Running the neural network...') + +# it iterates in the list of the user-selected fields +# of view, to return the corresponding index, the function +# dlg.listfov.row(item) is used which gives an integer + for item in dlg.listfov.selectedItems(): + +# iterates over the time indices in the range + for t in range(time_value1, time_value2+1): + +# calls the neural network for time t and selected +# fov + if dlg.entry_threshold.text() != '': + thr_val = float(dlg.entry_threshold.text()) + else: + thr_val = None + if dlg.entry_segmentation.text() != '': + seg_val = int(dlg.entry_segmentation.text()) + else: + seg_val = 10 + + self.PredThreshSeg(t, dlg.listfov.row(item), thr_val, seg_val) + + +# once it has iterated over all the fov, the message in +# the status bar is cleared and the buttons are enabled. + self.statusBar.clearMessage() + self.EnableCNNButtons() + + else: + return + else: + return + else: + return + + def PredThreshSeg(self, timeindex, fovindex, thr_val, seg_val): + """ + This function is called in the LaunchBatchPrediction function. + This function calls the neural network function in the + InteractionDisk.py file and then thresholds the result + of the prediction, saves this thresholded prediction. + Then it segments the thresholded prediction and saves the + segmentation. + """ +# launches the neural network + self.reader.LaunchPrediction(timeindex, fovindex) +# thresholds the prediction + self.m.ThresholdMask = self.reader.ThresholdPred(thr_val, timeindex,fovindex) +# saves the thresholded pred. + self.reader.SaveThresholdMask(timeindex, fovindex, self.m.ThresholdMask) +# segments the thresholded pred. + self.m.SegmentedMask = self.reader.Segment(seg_val, timeindex,fovindex) +# saves the segmentation + self.reader.SaveSegMask(timeindex, fovindex, self.m.SegmentedMask) + + + def LaunchPrediction(self): + """This function is not used in the gui, but it can be used to launch + the prediction of one picture, with no thresholding and no segmentation + """ + if not(self.reader.TestPredExisting(self.Tindex, self.FOVindex)): + self.statusBar.showMessage('Running the neural network...') + self.Disable(self.button_cnn) + self.reader.LaunchPrediction(self.Tindex, self.FOVindex) + + self.Enable(self.button_cnn) + + self.button_cnn.setEnabled(False) + self.button_threshold.setEnabled(True) + self.button_segment.setEnabled(True) + self.button_cellcorespondance.setEnabled(True) + self.statusBar.clearMessage() + + def ChangeOneValue(self): + """This function is called when the button Change cell value is + clicked. It displays the instructions on the status bar. + And if the user clicks in the graph where the current mask is displayed + it connects the event of the click (meaning that user has clicked on + one cell) to the function self.DialogBoxChangeOneValue. + This function will then replaces the cell selected by the user with + the click with a new value entered by the user. + """ + +# displaying the instructions on the statusbar + self.statusBar.showMessage('Select one cell using the left click and then enter the desired value in the dialog box') + +# disables all the buttons + self.Disable(self.button_changecellvalue) + +# connects the event "press mouse button" in the matplotlib plot +# (picture) to the function self.DialogBoxChangeOneValue + self.id = self.m.mpl_connect('button_press_event', self.DialogBoxChangeOneValue) + + + + def DialogBoxChangeOneValue(self, event): + """This function is called when the user after the user has selected + the button Change cell value and clicked in the picture to select + the desired cell to change. + + It first deconnects the mouse click event in matplotlib with this + function to not generate any other dialog window. + + It then tests if the click is inside the matplotlib plot (if it is + outside it equals to None) and if it is the current and editable plot + (the one in the middle of the gui, self.m.ax) + + If is true, then it sets the coordinates to int. and creates a dialog + window where the user is asked to type a value to set it to the cell. + + If the user presses ok, it tests if the entry is valid (>0 and not + empty) and looks for the old cell value and replaces it. And then + it updates the plot such that the result of the change can be seen. + """ + +# the function is disconnected from the matplotlib event. + self.m.mpl_disconnect(self.id) + +# test if the button is a left click and if the coordinates +# chosen by the user click is inside of the current matplotlib plot +# which is given by self.m.ax + if event.button == 1 and (event.xdata != None and event.ydata != None) and self.m.ax == event.inaxes: + newx = int(event.xdata) + newy = int(event.ydata) + +# creates a dialog window + dlg = cocv.CustomDialog(self) + +# if the user presses 'ok' in the dialog window it executes the code +# else it does nothing + if dlg.exec(): +# it tests that the user has entered some value, that it is not +# empty and that it is equal or bigger to 0. + if dlg.entry1.text() != '' and int(dlg.entry1.text()) >= 0: +# reads the new value to set and converts it from str to int + value = int(dlg.entry1.text()) + +# self.m.plotmask[newy, newx] the value selected by the user +# self.m.plotmask == self.m.plotmask[newy, newx] +# gives the coordinates where it is equal to the value +# selected by the user. And it replaces it with the new +# value. + self.m.plotmask[self.m.plotmask == self.m.plotmask[newy,newx]] = value +# updates the plot to see the modification. + self.m.updatedata() + +# if the button to show cell values is checked, then it +# replots the cell values + + if self.button_showval.isChecked(): + self.m.ShowCellNumbersCurr() + self.m.ShowCellNumbersNext() + self.m.ShowCellNumbersPrev() + +# enables the button again + self.Enable(self.button_changecellvalue) + +# clears the message in the status bar + self.statusBar.clearMessage() + + + +# the button is a checkable and it has to be unchecked else it seems +# that the button is still in use, because it gets a blue color. + self.button_changecellvalue.setChecked(False) + + + + + def DialogBoxECV(self, s): + """This functions creates from the ExchangeCellValues.py file a + window which takes two integer entries and then swaps the cells having + the given integer values. + """ +# creates a dialog window from the ExchangeCellValues.py file + dlg = ecv.CustomDialog(self) + +# if the user presses 'ok', it executes the code + if dlg.exec_(): + +# it tests if both value to be swapped are not empty. + if dlg.entry1.text()!= '' and dlg.entry2.text() != '': + +# reads out the values and converts it into integers. + value1 = int(dlg.entry1.text()) + value2 = int(dlg.entry2.text()) + +# calls the function which does the swap + self.m.ExchangeCellValue(value1,value2) + +# if the button to display the values of the cell is checked, +# the values are again displayed on the graph after the swap +# of cells. + if self.button_showval.isChecked(): + self.m.ShowCellNumbersCurr() + self.m.ShowCellNumbersNext() + self.m.ShowCellNumbersPrev() + + + else: + return + + + + def SelectChannel(self, index): + """This function is called when the button to select different channels + is used. From the displayed list in the button, the chosen index + corresponnds to the same index in the list of channels from the reader. + So, it sets the default channel with the new index (called index below) + """ + +# This function attributes the FOV chosen by the user corresponding to +# the index in the list of options. + self.reader.default_channel = index +# update the pictures using the same function as the one used to +# change the fields of view. + self.ChangeFOV() + + + + + def SelectFov(self, index): + """This function is called when the button containing the list of + fields od view is used. + The index correspondds to the field of view selected in the list. + + """ + +# This function attributes the FOV chosen by the user corresponding to +# the index in the list of options. First the mask is automatically +# saved. + self.reader.SaveMask(self.Tindex, self.FOVindex, self.m.plotmask) +# the new index is set. + self.FOVindex = index + +# it updates the fov in the plot with the new index. + self.ChangeFOV() + +# the flag of the fluorescence extraction is set to False (such that +# if the user extracts fluorescence data in the new field of view, +# there is a dialog box asking to select the corresponding xls file +# for this field of view. IF there is no data sheet for this fov, the +# user can enter a new name to make a new file.) + self.FlagFluoExtraction = False + + def ChangeFOV(self): + +# it changes the fov or channel according to the choice of the user +# and it updates the plot shown and it initializes the new fov/channel +# at t=0 by default. + +# set the time index to 0 + self.Tindex = 0 + +# load the image and mask for the current plot + self.m.currpicture = self.reader.LoadOneImage(self.Tindex,self.FOVindex) + self.m.plotmask = self.reader.LoadMask(self.Tindex,self.FOVindex) + +# sets the image and the mask to 0 for the previous plot + self.m.prevpicture = np.zeros([self.reader.sizey, self.reader.sizex], dtype = np.uint16) + self.m.prevplotmask = np.zeros([self.reader.sizey, self.reader.sizex], dtype = np.uint16) + +# load the image and the mask for the next plot. + self.m.nextpicture = self.reader.LoadOneImage(self.Tindex+1, self.FOVindex) + self.m.nextplotmask = self.reader.LoadMask(self.Tindex+1, self.FOVindex) + + +# once the images and masks are loaded into the variables, they are +# displaye in the gui. + self.m.UpdateBckgrndPicture() + + +# enables the next frame button in case it was disabled when the +# fov/channel was changed + self.button_nextframe.setEnabled(True) +# disables the previous frame button in case it was active before +# changing fov/channel. + self.button_previousframe.setEnabled(False) + +# updates the title of the plots to display the right time indices +# aboves the plots. + self.UpdateTitleSubplots() + +# if the button to show cell values is active, it shows the values again. + if self.button_showval.isChecked(): + self.m.ShowCellNumbersCurr() + self.m.ShowCellNumbersNext() + self.m.ShowCellNumbersPrev() + +# if the button to hide the mask was checked before changing fov/channel, +# it hides the mask again. + if self.button_hidemask.isChecked(): + self.m.HideMask() + +# the button to set the time index is also set to 0/default again. + self.button_timeindex.setText('') +# enables the neural network buttons if there is already an +# existing prediction for the current image. + self.EnableCNNButtons() + + def ChangeTimeFrame(self): + """This funcion is called whenever the user gives an new time index, + to jump to the new given index, onces "enter" button is pressed. + """ + +# it reads out the text in the button and converts it to an int. + newtimeindex = int(self.button_timeindex.text()) + if newtimeindex >= 0 and newtimeindex <= self.reader.sizet-1: + self.reader.SaveMask(self.Tindex, self.FOVindex, self.m.plotmask) + + self.Tindex = newtimeindex + + if self.Tindex == 0: + self.button_nextframe.setEnabled(True) + self.m.nextpicture = self.reader.LoadOneImage(self.Tindex+1,self.FOVindex) + self.m.nextplotmask = self.reader.LoadMask(self.Tindex+1, self.FOVindex) + + self.m.currpicture = self.reader.LoadOneImage(self.Tindex, self.FOVindex) + self.m.plotmask = self.reader.LoadMask(self.Tindex, self.FOVindex) + + self.m.prevpicture = np.zeros([self.reader.sizey, self.reader.sizex], dtype = np.uint16) + self.m.prevplotmask = np.zeros([self.reader.sizey, self.reader.sizex], dtype = np.uint16) + + + self.m.UpdateBckgrndPicture() + self.button_previousframe.setEnabled(False) + + + elif self.Tindex == self.reader.sizet-1: + self.button_previousframe.setEnabled(True) + self.m.prevpicture = self.reader.LoadOneImage(self.Tindex-1, self.FOVindex) + self.m.prevplotmask = self.reader.LoadMask(self.Tindex-1, self.FOVindex) + + self.m.currpicture = self.reader.LoadOneImage(self.Tindex, self.FOVindex) + self.m.plotmask = self.reader.LoadMask(self.Tindex, self.FOVindex) + + + self.m.nextpicture = np.zeros([self.reader.sizey, self.reader.sizex], dtype = np.uint16) + self.m.nextplotmask = np.zeros([self.reader.sizey, self.reader.sizex], dtype = np.uint16) + + self.m.UpdateBckgrndPicture() + self.button_nextframe.setEnabled(False) + + + else: + + self.button_nextframe.setEnabled(True) + self.button_previousframe.setEnabled(True) + self.m.prevpicture = self.reader.LoadOneImage(self.Tindex-1, self.FOVindex) + self.m.prevplotmask = self.reader.LoadMask(self.Tindex-1, self.FOVindex) + + self.m.currpicture = self.reader.LoadOneImage(self.Tindex, self.FOVindex) + self.m.plotmask = self.reader.LoadMask(self.Tindex, self.FOVindex) + + self.m.nextpicture = self.reader.LoadOneImage(self.Tindex+1,self.FOVindex) + self.m.nextplotmask = self.reader.LoadMask(self.Tindex+1, self.FOVindex) + + self.m.UpdateBckgrndPicture() + + self.UpdateTitleSubplots() + self.button_timeindex.clearFocus() + + if self.button_showval.isChecked(): + self.m.ShowCellNumbersCurr() + self.m.ShowCellNumbersNext() + self.m.ShowCellNumbersPrev() + + if self.button_hidemask.isChecked(): + self.m.HideMask() + self.EnableCNNButtons() + + else: + self.button_timeindex.clearFocus() + return + +# def keyPressEvent(self, event): +## print('keypressevengda211111') +## print(event.key()) +# if not self.button_nextframe.isChecked() and event.key() == Qt.Key_Right: +# self.button_nextframe.setChecked(True) +# self.ForwardTime() +# self.button_nextframe.setChecked(False) +# else: +# print('inside keypressevent') +# event.ignore() + + def CellCorrespActivation(self): + self.Disable(self.button_cellcorespondance) + self.statusBar.showMessage('Doing the cell correspondance') + + if self.Tindex != 0: + self.m.plotmask,_ = self.reader.CellCorrespondance(self.Tindex, self.FOVindex) + self.m.updatedata() + else: + self.m.plotmask = self.reader.LoadSeg(self.Tindex, self.FOVindex) + self.m.updatedata() + + self.Enable(self.button_cellcorespondance) + self.button_cellcorespondance.setChecked(False) + self.statusBar.clearMessage() + + def SegmentBoxCheck(self): + + if self.button_segment.isChecked(): + + self.Disable(self.button_segment) + self.EnableCorrectionsButtons() + self.m.SegmentedMask = self.reader.LoadSeg(self.Tindex, self.FOVindex) + self.m.tempplotmask = self.m.plotmask.copy() + self.m.plotmask = self.m.SegmentedMask.copy() + self.m.currmask.set_data((self.m.SegmentedMask%10 + 1)*(self.m.SegmentedMask != 0)) + self.m.ax.draw_artist(self.m.currplot) + self.m.ax.draw_artist(self.m.currmask) + self.m.update() + self.m.flush_events() +# update the graph + + self.button_SetSegmentation.setEnabled(True) + self.button_savesegmask.setEnabled(True) + else: + self.m.SegmentedMask = self.m.plotmask.copy() + self.m.plotmask = self.m.tempplotmask.copy() + self.m.updatedata() + self.button_SetSegmentation.setEnabled(False) + self.button_savesegmask.setEnabled(False) + self.Enable(self.button_segment) + + def SegmentThresholdedPredMask(self): + + segparamvalue = int(self.button_SetSegmentation.text()) + self.m.plotmask = self.reader.Segment(segparamvalue, self.Tindex,self.FOVindex) + self.m.currmask.set_data((self.m.plotmask%10 + 1)*(self.m.plotmask != 0)) + self.m.ax.draw_artist(self.m.currplot) + self.m.ax.draw_artist(self.m.currmask) + self.m.update() + self.m.flush_events() +# self.m.SegmentedMask = self.reader.Segment(segparamvalue, self.Tindex, self.FOVindex) +# update the plots to display the segmentation view + + def ButtonSaveSegMask(self): + """saves the segmented mask + """ + self.reader.SaveSegMask(self.Tindex, self.FOVindex, self.m.plotmask) + + + + + def ThresholdBoxCheck(self): + """if the buttons is checked it shows the thresholded version of the + prediction, if it is not available it justs displays a null array. + The buttons for the setting a threshold a value and to save it are then + activated once this button is enabled. + """ + if self.button_threshold.isChecked(): + self.Disable(self.button_threshold) + self.m.ThresholdMask = self.reader.LoadThreshold(self.Tindex, self.FOVindex) + + self.m.currmask.set_data(self.m.ThresholdMask) + self.m.ax.draw_artist(self.m.currplot) + self.m.ax.draw_artist(self.m.currmask) + self.m.update() + self.m.flush_events() + + +# update the gra + + self.button_SetThreshold.setEnabled(True) + self.button_savethresholdmask.setEnabled(True) + else: + self.m.updatedata() + + self.button_SetThreshold.setEnabled(False) + self.button_savethresholdmask.setEnabled(False) + self.Enable(self.button_threshold) + + def ThresholdPrediction(self): + + thresholdvalue = float(self.button_SetThreshold.text()) + + self.m.ThresholdMask = self.reader.ThresholdPred(thresholdvalue, self.Tindex,self.FOVindex) + self.m.currmask.set_data(self.m.ThresholdMask) + self.m.ax.draw_artist(self.m.currplot) + self.m.ax.draw_artist(self.m.currmask) + self.m.update() + self.m.flush_events() +# self.m.ThresholdMask = self.reader.ThresholdPred(thresholdvalue, self.Tindex, self.FOVindex) +# update the plots to display the thresholded view + + def ButtonSaveThresholdMask(self): + """saves the thresholed mask + """ +# pass + self.reader.SaveThresholdMask(self.Tindex, self.FOVindex, self.m.ThresholdMask) + + + def ChangePreviousFrame(self): + + """This function is called when the previous frame buttons is pressed + and it tests if the buttons is enabled and if so it calls the + BackwardTime() function. It should avoid the let the user do multiple + clicks and that the function is then called afterwards several times, + once the frames and masks of the current time index have been loaded. + """ + + if self.button_previousframe.isEnabled(): + self.button_previousframe.setEnabled(False) + +# self.button_nextframe.disconnect() +# self.button_nextframe.setShortcut('') + +# self.button_nextframe.setEnabled(False) +# print(self.button_nextframe.isEnabled()) +# self.button_nextframe.setChecked(False) +# self.Testbool = False + self.BackwardTime() +# self.button_nextframe.setShortcut(Qt.Key_Right) +# self.button_nextframe.pressed.connect(self.Test) +# self.button_nextframe.setChecked(False) + if self.Tindex >0: + self.button_previousframe.setEnabled(True) +# self.button_nextframe.pressed.connect(self.Test) + +# self.button_nextframe.setChecked(False) + +# self.Testbool = True + else: +# print('jamais la dedans?') + return + + + def ChangeNextFrame(self): + + """This function is called when the next frame buttons is pressed + and it tests if the buttons is enabled and if so it calls the + ForwardTime() function. It should avoid the let the user do multiple + clicks and that the function is then called afterwards several times, + once the frames and masks of the current time index have been loaded. + """ +# self.button_nextframe.setShortcutEnabled(False) + if self.button_nextframe.isEnabled(): + self.button_nextframe.setEnabled(False) + +# self.button_nextframe.disconnect() +# self.button_nextframe.setShortcut('') + +# self.button_nextframe.setEnabled(False) +# print(self.button_nextframe.isEnabled()) +# self.button_nextframe.setChecked(False) +# self.Testbool = False + self.ForwardTime() +# self.button_nextframe.setShortcut(Qt.Key_Right) +# self.button_nextframe.pressed.connect(self.Test) +# self.button_nextframe.setChecked(False) + + self.button_nextframe.setEnabled(True) +# self.button_nextframe.pressed.connect(self.Test) + +# self.button_nextframe.setChecked(False) + +# self.Testbool = True + else: + return +# print('jamais la dedans?') +# if QKeyEvent.key() == Qt.Key_Right: +# QKeyEvent.ignore() + + + def ForwardTime(self): + + """This function switches the frame in forward time index. And it tests + several conditions if t == lastTimeIndex-1, because then the next frame + button has to be disabled. It also tests if the show value of cells + button and hidemask are active in order to hide/show the mask or to + show the cell values. + """ +# print(self.Tindex) +# the t frame is defined as the currently shown frame on the display. +# If the button "Next time frame" is pressed, this function is called + self.statusBar.showMessage('Loading the next frame...') +# self.button_nextframe.setEnabled(False) +# self.button_nextframe.disconnect() + self.Disable(self.button_nextframe) + + + if self.Tindex + 1 < self.reader.sizet - 1 : + self.reader.SaveMask(self.Tindex, self.FOVindex, self.m.plotmask) + + self.m.prevpicture = self.m.currpicture.copy() + self.m.prevplotmask = self.m.plotmask.copy() + + self.m.currpicture = self.m.nextpicture.copy() + self.m.plotmask = self.m.nextplotmask.copy() + + + self.m.nextpicture = self.reader.LoadOneImage(self.Tindex+2, self.FOVindex) + self.m.nextplotmask = self.reader.LoadMask(self.Tindex+2, self.FOVindex) + self.m.UpdateBckgrndPicture() + + if self.Tindex + 1 == 1: + self.button_previousframe.setEnabled(True) + + + +# + else: + self.reader.SaveMask(self.Tindex, self.FOVindex, self.m.plotmask) + + self.m.prevpicture = self.m.currpicture.copy() + self.m.prevplotmask = self.m.plotmask.copy() + self.m.currpicture = self.m.nextpicture.copy() + self.m.plotmask = self.m.nextplotmask.copy() + + + + self.m.nextpicture = np.zeros([self.reader.sizey, self.reader.sizex], dtype = np.uint16) + self.m.nextplotmask = np.zeros([self.reader.sizey,self.reader.sizex], dtype = np.uint16) + self.m.UpdateBckgrndPicture() + + self.button_nextframe.setEnabled(False) + + if self.button_showval.isChecked(): + self.m.ShowCellNumbersCurr() + self.m.ShowCellNumbersNext() + self.m.ShowCellNumbersPrev() + + + + + + self.Tindex = self.Tindex+1 + self.UpdateTitleSubplots() + + if self.button_hidemask.isChecked(): + self.m.HideMask() + + self.Enable(self.button_nextframe) + +# self.button_nextframe.setChecked(False) + + self.statusBar.clearMessage() +# if self.Tindex < self.reader.sizet - 1 : +# self.button_nextframe.setEnabled(True) + self.button_timeindex.setText(str(self.Tindex)) + + + def BackwardTime(self): + + """This function switches the frame in backward time index. And it + several conditions if t == 1, because then the button previous frame has to + be disabled. It also tests if the show value of cells button and + hidemask are active in order to hide/show the mask or to show the cell + values. + """ +# print(self.Tindex) +# the t frame is defined as the currently shown frame on the display. +# If the button "Previous time frame" is pressed, this function is called + self.statusBar.showMessage('Loading the previous frame...') +# self.button_previousframe.setEnabled(False) +# self.button_previousframe.disconnect() + self.Disable(self.button_previousframe) + if self.Tindex == 1: + + self.reader.SaveMask(self.Tindex, self.FOVindex, self.m.plotmask) + + self.m.nextpicture = self.m.currpicture.copy() + self.m.nextplotmask = self.m.plotmask.copy() + + self.m.currpicture = self.m.prevpicture.copy() + self.m.plotmask = self.m.prevplotmask.copy() + + self.m.prevpicture = np.zeros([self.reader.sizey, self.reader.sizex], dtype = np.uint16) + self.m.prevplotmask = np.zeros([self.reader.sizey, self.reader.sizex], dtype = np.uint16) + + + self.m.UpdateBckgrndPicture() + + + self.button_previousframe.setEnabled(False) + + + + else: + + + self.reader.SaveMask(self.Tindex, self.FOVindex, self.m.plotmask) + + self.m.nextpicture = self.m.currpicture.copy() + self.m.nextplotmask = self.m.plotmask.copy() + + self.m.currpicture = self.m.prevpicture.copy() + self.m.plotmask = self.m.prevplotmask.copy() + + self.m.prevpicture = self.reader.LoadOneImage(self.Tindex-2, self.FOVindex) + self.m.prevplotmask = self.reader.LoadMask(self.Tindex-2, self.FOVindex) + + self.m.UpdateBckgrndPicture() + if self.Tindex-1 == self.reader.sizet-2: + self.button_nextframe.setEnabled(True) + + + if self.button_showval.isChecked(): + self.m.ShowCellNumbersCurr() + self.m.ShowCellNumbersNext() + self.m.ShowCellNumbersPrev() + + + + +# self.button_previousframe.clicked.connect(self.BackwardTime) + + self.Tindex = self.Tindex-1 + self.UpdateTitleSubplots() + + if self.button_hidemask.isChecked(): + self.m.HideMask() + + self.Enable(self.button_previousframe) + + + if self.Tindex > 0: + self.button_previousframe.setEnabled(True) + +# self.button_previousframe.setChecked(False) + self.statusBar.clearMessage() + self.button_timeindex.setText(str(self.Tindex)) + + def MouseDraw(self): + + """ + This function is called whenever the brush or the eraser button is + pressed. On the first press event it calls the self.m.OneClick, which + tests whether it is a right click or a left click. If it a right click + it assigns the value of the pixel which has been right clicked + to self.cellval, meaning that the next drawn pixels will be set to this + value. + If it is left clicked, then it draws a 3x3 square with the current + value of self.cellval. + If after left clicking you drag the mouse, then you start drawing + using the mouse and it stops once you release the left click. + + Same for the eraser button, it sets directly the value of self.cellval + to 0. + """ + + + if self.button_drawmouse.isChecked(): + + self.statusBar.showMessage('Drawing using the brush, right click to set a value...') + + self.Disable(self.button_drawmouse) + self.m.tempmask = self.m.plotmask.copy() + + self.id2 = self.m.mpl_connect('button_press_event', self.m.OneClick) + + + + self.id = self.m.mpl_connect('motion_notify_event', self.m.PaintBrush) + + self.id3 = self.m.mpl_connect('button_release_event', self.m.ReleaseClick) + + + pixmap = QtGui.QPixmap('./icons/brush2.png') + cursor = QtGui.QCursor(pixmap, 1,1) + QApplication.setOverrideCursor(cursor) + + elif self.button_eraser.isChecked(): + + self.statusBar.showMessage('Erasing by setting the values to 0...') + self.Disable(self.button_eraser) + + self.m.tempmask = self.m.plotmask.copy() + + self.m.cellval = 0 + self.id2 = self.m.mpl_connect('button_press_event', self.m.OneClick) + self.id = self.m.mpl_connect('motion_notify_event', self.m.PaintBrush) + + self.id3 = self.m.mpl_connect('button_release_event', self.m.ReleaseClick) + + + pixmap = QtGui.QPixmap('./icons/eraser.png') + cursor = QtGui.QCursor(pixmap, 1,1) + QApplication.setOverrideCursor(cursor) + + else: + self.m.mpl_disconnect(self.id3) + self.m.mpl_disconnect(self.id2) + self.m.mpl_disconnect(self.id) + QApplication.restoreOverrideCursor() + self.Enable(self.button_drawmouse) + self.Enable(self.button_eraser) + + if self.button_showval.isChecked(): + self.m.ShowCellNumbersCurr() + self.m.ShowCellNumbersNext() + self.m.ShowCellNumbersPrev() + self.statusBar.clearMessage() + + + def UpdateTitleSubplots(self): + """This function updates the title of the plots according to the + current time index. So it called whenever a frame or a fov is changed. + """ + if self.Tindex == 0: + + self.m.titlecurr.set_text('Time index {}'.format(self.Tindex)) + self.m.titleprev.set_text('No frame {}'.format('')) + self.m.titlenext.set_text('Next Time index {}'.format(self.Tindex+1)) + + +# self.m.ax.set_title('Time index {}'.format(self.Tindex)) +# self.m.ax2.set_title('No frame {}'.format('')) +# self.m.ax3.set_title('Next Time index {}'.format(self.Tindex+1)) +# self.m.update() +# self.m.flush_events() + self.m.draw() + elif self.Tindex == self.reader.sizet-1: + + self.m.titlecurr.set_text('Time index {}'.format(self.Tindex)) + self.m.titleprev.set_text('Previous time index {}'.format(self.Tindex-1)) + self.m.titlenext.set_text('No frame {}'.format('')) + + + + +# self.m.ax.set_title('Time index {}'.format(self.Tindex)) +# self.m.ax2.set_title('Previous time index {}'.format(self.Tindex-1)) +# self.m.ax3.set_title('No frame {}'.format('')) +# self.m.update() +# self.m.flush_events() + self.m.draw() + else: + self.m.titlecurr.set_text('Time index {}'.format(self.Tindex)) + self.m.titleprev.set_text('Previous time index {}'.format(self.Tindex-1)) + self.m.titlenext.set_text('Next Time index {}'.format(self.Tindex+1)) + + +# self.m.ax.set_title('Time index {}'.format(self.Tindex)) +# self.m.ax2.set_title('Previous time index {}'.format(self.Tindex-1)) +# self.m.ax3.set_title('Next Time index {}'.format(self.Tindex+1)) +# self.m.update() +# self.m.flush_events() + self.m.draw() + + def ClickNewCell(self): + """ + this method is called when the button New Cell is clicked. If the button + state corresponds to True (if is activated) then it connects the mouse + clicks on the pyqt window to the canvas (so to the matplolib figure). + The connection has an "id" which is given by the integer self.id + After the connections is made, it calls the Disable function with argument 0 + which turns off the other button(s). + + If the button is clicked but it is deactivated then it disconnects the + connection between the canvas and the window (the user can not interact + with the plot anymore). + Storemouseclicks is a list corresponding to the coordinates of all mouse + clicks between the activation and the deactivation of the button. + So if it is empty, it does not draw anything because no clicks + were registered. + But if it has some coordinates, it will draw a polygon where the vertices + are the coordinates of all the mouseclicks. + Once the figure has been updated with a new polygon, the other button(s) + are again enabled. + + """ + if self.button_newcell.isChecked(): + self.statusBar.showMessage('Draw a new cell...') + self.m.tempmask = self.m.plotmask.copy() + self.id = self.m.mpl_connect('button_press_event', self.m.MouseClick) + self.Disable(self.button_newcell) + + + + else: + + self.m.mpl_disconnect(self.id) + if self.m.storemouseclicks and self.TestSelectedPoints(): + self.m.DrawRegion(True) + else: + self.m.updatedata() + self.Enable(self.button_newcell) + if self.button_showval.isChecked(): + self.m.ShowCellNumbersCurr() + self.m.ShowCellNumbersNext() + self.m.ShowCellNumbersPrev() + self.statusBar.clearMessage() + + + def TestSelectedPoints(self): + + """This function is just used to catch an exception, when the new cell + or the add region function is called. If all the dots drawn by the user + are located on one line (horizontal or vertical) the DrawRegion + function calls a method to create a polygon and + it can not make a polygon out of straight line so it gives an error. + In order to prevent this error, this function avoids to attempt to draw + by returning False if the square are all on one line. + """ + + + + allx = list(np.array(self.m.storemouseclicks)[:,0]) + ally = list(np.array(self.m.storemouseclicks)[:,1]) + + resultx = all(elem == allx[0] for elem in allx) + resulty = all(elem == ally[0] for elem in ally) + + if resultx or resulty: + return False + else: + return True + + + + + def clickmethod(self): + """ + this method is called when the button Add region is clicked. If the button + state corresponds to True (if is activated) then it connects the mouse + clicks on the pyqt window to the canvas (so to the matplolib figure). + The connection has an "id" which is given by the integer self.id + After the connections is made, it calls the Disable function with argument 1 + which turns off the other button(s). + + If the button is clicked and it is deactivated then it disconnects the + connection between the canvas and the window (the user can not interact + with the plot anymore). + Storemouseclicks is a list corresponding to the coordinates of all mouse + clicks between the activation and the deactivation of the button. + So if it is empty, it does not draw anything because no clicks + were registered. + But if it has some coordinates, it will draw a polygon where the vertices + are the coordinates of all the mouseclicks. + Once the figure has been updated with a new polygon, the other button(s) + are again enabled. + + """ + if self.button_add_region.isChecked(): + self.statusBar.showMessage('Adding a region to an existing cell...') + self.m.tempmask = self.m.plotmask.copy() + self.id = self.m.mpl_connect('button_press_event', self.m.MouseClick) + self.Disable(self.button_add_region) + + else: + + self.m.mpl_disconnect(self.id) + +# test if the list is not empty and if the dots are not all in the same line + if self.m.storemouseclicks and self.TestSelectedPoints(): + + self.m.DrawRegion(False) + + else: + self.m.updatedata() + self.Enable(self.button_add_region) + + if self.button_showval.isChecked(): + self.m.ShowCellNumbersCurr() + self.m.ShowCellNumbersNext() + self.m.ShowCellNumbersPrev() + + self.statusBar.clearMessage() + + + + + + def Enable(self, button): + + """ + this functions turns on buttons all the buttons, depending on the time + index. (next and previous buttons should not be turned on if t = 0 + or t = lasttimeindex) + """ + if self.button_segment.isChecked(): + self.EnableCorrectionsButtons() + self.button_home.setEnabled(True) + self.button_zoom.setEnabled(True) + self.button_pan.setEnabled(True) + self.button_back.setEnabled(True) + self.button_forward.setEnabled(True) + else: + for k in range(0, len(self.buttonlist)): + if button != self.buttonlist[k]: + self.buttonlist[k].setEnabled(True) + + if self.Tindex == 0: + self.button_previousframe.setEnabled(False) + + if self.Tindex == self.reader.sizet-1: + self.button_nextframe.setEnabled(False) + + self.EnableCNNButtons() + + + def Disable(self, button): + + """ + this functions turns off all the buttons except the one given in + argument. + """ + flag = False + if button == self.button_add_region or button == self.button_newcell or button == self.button_exval or button == self.button_changecellvalue or button == self.button_drawmouse or button == self.button_eraser: + if self.button_segment.isChecked(): + flag = True + + + for k in range(0,len(self.buttonlist)): + if button != self.buttonlist[k]: + self.buttonlist[k].setEnabled(False) + if flag: + self.button_segment.setEnabled(True) + + if button == self.button_segment or button == self.button_threshold: + self.button_home.setEnabled(True) + self.button_zoom.setEnabled(True) + self.button_pan.setEnabled(True) + self.button_back.setEnabled(True) + self.button_forward.setEnabled(True) + + def EnableCNNButtons(self): + + if self.reader.TestPredExisting(self.Tindex, self.FOVindex): +# self.button_cnn.setEnabled(False) + self.button_threshold.setEnabled(True) + self.button_segment.setEnabled(True) + self.button_cellcorespondance.setEnabled(True) + self.button_extractfluorescence.setEnabled(True) + else: +# self.button_cnn.setEnabled(True) + self.button_threshold.setEnabled(False) + self.button_segment.setEnabled(False) + self.button_cellcorespondance.setEnabled(False) + self.button_extractfluorescence.setEnabled(False) + + def EnableCorrectionsButtons(self): + self.button_newcell.setEnabled(True) + self.button_add_region.setEnabled(True) + self.button_drawmouse.setEnabled(True) + self.button_eraser.setEnabled(True) + self.button_exval.setEnabled(True) + self.button_changecellvalue.setEnabled(True) + self.button_showval.setEnabled(True) + + def DisableCorrectionsButtons(self): + self.button_newcell.setEnabled(False) + self.button_add_region.setEnabled(False) + self.button_drawmouse.setEnabled(False) + self.button_eraser.setEnabled(False) + self.button_exval.setEnabled(False) + self.button_changecellvalue.setEnabled(False) + self.button_showval.setEnabled(False) + + def ButtonSaveMask(self): + """ + When this function is called, it saves the current mask + (self.m.plotmask) + """ + + self.reader.SaveMask(self.Tindex, self.FOVindex, self.m.plotmask) + + + +class PlotCanvas(FigureCanvas): + + def __init__(self, parent=None): + """this class defines the canvas. It initializes a figure, which is then + used to plot our data using imshow. + + """ + + +# define three subplots corresponding to the previous, current and next +# time index. + fig, (self.ax2, self.ax, self.ax3) = plt.subplots(1,3, sharex = True, sharey = True) + + FigureCanvas.__init__(self, fig) + self.setParent(parent) + +# this is some mambo jambo. + FigureCanvas.setSizePolicy(self, + QSizePolicy.Expanding, + QSizePolicy.Expanding) + FigureCanvas.updateGeometry(self) + +# the self.currpicture attribute takes the original data and will then +# contain the updates drawn by the user. + + self.currpicture = parent.currentframe + + self.prevpicture = parent.previousframe + + self.nextpicture = parent.nextframe + + self.plotmask = parent.mask_curr + + self.prevplotmask = parent.mask_previous + + self.nextplotmask = parent.mask_next + + self.tempmask = self.plotmask.copy() + + self.tempplotmask = self.plotmask.copy() + + self.ThresholdMask = np.zeros([parent.reader.sizey, parent.reader.sizex], dtype = np.uint16) + self.SegmentedMask = np.zeros([parent.reader.sizey, parent.reader.sizex], dtype = np.uint16) + +# this line is just here to not attribute a zero value to the plot +# because if so, then it does not update the plot and it stays blank. +# (it is unclear why..if someone finds a better solution) + self.prevpicture = self.currpicture.copy() + + + self.currplot, self.currmask = self.plot(self.currpicture, self.plotmask, self.ax) + + self.previousplot, self.previousmask = self.plot(self.prevpicture, self.prevplotmask, self.ax2) + self.prevpicture = np.zeros([parent.reader.sizey, parent.reader.sizex], dtype = np.uint16) + self.prevplotmask = np.zeros([parent.reader.sizey, parent.reader.sizex], dtype =np.uint16) + +# print('set visible') +# self.ax2.set_visible(False) + + self.nextplot, self.nextmask = self.plot(self.nextpicture, self.nextplotmask, self.ax3) + + self.previousplot.set_data(self.prevpicture) + self.previousmask.set_data((self.prevplotmask%10+1)*(self.prevplotmask != 0)) + + self.ax2.draw_artist(self.previousplot) + self.ax2.draw_artist(self.previousmask) + self.update() + self.flush_events() + + + self.titlecurr = self.ax.set_title('Time index {}'.format(parent.Tindex)) + self.titleprev = self.ax2.set_title('No frame {}'.format('')) + self.titlenext = self.ax3.set_title('Next Time index {}'.format(parent.Tindex+1)) + + +# these variables are just set to test the states of the buttons +# (button turned on or off, etc..) of the buttons in the methods +# used in this class. + self.button_showval_check = parent.button_showval + self.button_newcell_check = parent.button_newcell + self.button_add_region_check = parent.button_add_region + self.button_drawmouse_check = parent.button_drawmouse + self.button_eraser_check = parent.button_eraser + self.button_hidemask_check = parent.button_hidemask + + +# It will plot for the first time and return the imshow function + + self.currmask.set_clim(0, 10) + self.previousmask.set_clim(0,10) + self.nextmask.set_clim(0,10) + +# This attribute is a list which stores all the clicks of the mouse. + self.storemouseclicks = [] + +# This attribute is used to store the square where the mouse has been +# in order than to draw lines (Paintbrush function) + self.storebrushclicks = [[False,False]] + +# self.cellval is the variable which sets the value to the pixel +# whenever something is drawn. + self.cellval = 0 +# self.store_values = [] + +# These are the codes used to create a polygon in the new cell/addregion +# functions, which should be fed into the Path function + self.codes_drawoneline = [Path.MOVETO, Path.LINETO] + +# These are lists storing all the annotations which are used to +# show the values of the cells on the plots. + self.ann_list = [] + self.ann_list_prev = [] + self.ann_list_next = [] + + + + + + def ExchangeCellValue(self, val1, val2): + """Swaps the values of the cell between two clusters each representing + one cell. This method is called after the user has entered + values in the ExchangeCellValues window. + """ + + + if (val1 in self.plotmask) and (val2 in self.plotmask): + indices = np.where(self.plotmask == val1) + self.plotmask[self.plotmask == val2] = val1 + for i in range(0,len(indices[0])): + self.plotmask[indices[0][i], indices[1][i]] = val2 + self.updatedata() + + else: +# print('No cell values found corresponding to the entered value') + return + + def ReleaseClick(self, event): + """This method is called from the brush button when the mouse is + released such that the last coordinate saved when something is drawn + is set to zero. Because otherwise, if the user starts drawing somewhere + else, than a straight line is draw between the last point of the + previous mouse drawing/dragging and the new one which then starts. + """ + if self.ax == event.inaxes: + self.storebrushclicks[0] = [False, False] + + def OneClick(self, event): + """This method is called when the Brush button is activated. And + sets the value of self.cellval if the click is a right click, or draws + a square if the click is a left click. (so if the user does just left + click but does not drag, there will be only a square which is drawn ) + """ + if event.button == 3 and (event.xdata != None and event.ydata != None) and (not self.button_eraser_check.isChecked()) and self.ax == event.inaxes: + tempx = int(event.xdata) + tempy = int(event.ydata) + self.cellval = self.plotmask[tempy, tempx] + self.storebrushclicks[0] = [False, False] + + elif event.button == 1 and (event.xdata != None and event.ydata != None) and self.ax == event.inaxes: + tempx = int(event.xdata) + tempy = int(event.ydata) + self.plotmask[tempy:tempy+3, tempx:tempx+3] = self.cellval + self.storebrushclicks[0] = [tempx,tempy] + + self.updatedata() + + else: + return + + + def PaintBrush(self, event): + """PantBrush is the method to paint using a "brush" and it is based + on the mouse event in matplotlib "motion notify event". However it can + not record every pixel that the mouse has hovered over (it is too fast). + So, in order to not only draw points (happens when the mouse is dragged + too quickly), these points are interpolated here with lines. + """ + + if event.button == 1 and (event.xdata != None and event.ydata != None) and self.ax == event.inaxes: + + newx = int(event.xdata) + newy = int(event.ydata) +# when a new cell value is set, there is no point to interpolate, to +# draw a line between the points. + if self.storebrushclicks[0][0] == False : + self.plotmask[newy:newy+3,newx:newx+3] = self.cellval + self.storebrushclicks[0] = [newx,newy] + else: + oldx = self.storebrushclicks[0][0] + oldy = self.storebrushclicks[0][1] + + if newx != oldx: + + slope = (oldy-newy)/(oldx-newx) + offset = (newy*oldx-newx*oldy)/(oldx-newx) + + if newx > oldx: + for xtemp in range(oldx+1, newx+1): + ytemp = int(slope*xtemp + offset) + self.plotmask[ytemp:ytemp + 3, xtemp:xtemp+3] = self.cellval + else: + for xtemp in range(oldx-1,newx-1,-1): + ytemp = int(slope*xtemp + offset) + self.plotmask[ytemp:ytemp+3, xtemp:xtemp+3] = self.cellval + else: + if newy > oldy: + for ytemp in range(oldy+1,newy+1): + self.plotmask[ytemp:ytemp+3, newx:newx+3] = self.cellval + else: + for ytemp in range(oldy-1,newy-1,-1): + self.plotmask[ytemp:ytemp+3, newx:newx+3] = self.cellval + + self.storebrushclicks[0][0] = newx + self.storebrushclicks[0][1] = newy + + self.updatedata() + + + + def MouseClick(self,event): + """This function is called whenever, the add region or the new cell + buttons are active and the user clicks on the plot. For each + click on the plot, it records the coordinate of the click and stores + it. When the user deactivate the new cell or add region button, + all the coordinates are given to the DrawRegion function (if they + do not all lie on the same line) and out of the coordinates, it makes + a polygon. And then draws inside of this polygon by setting the pixels + to the self.cellval value. + """ +# button == 1 corresponds to the left click. + + if event.button == 1 and (event.xdata != None and event.ydata != None) and self.ax == event.inaxes: + +# extract the coordinate of the click inside of the matplotlib figure +# and then takes the integer part + + newx = int(event.xdata) + newy = int(event.ydata) + +# print(newx,newy)m +# stores the coordinates of the click + + self.storemouseclicks.append([newx, newy]) +# draws in the figure a small square (4x4 pixels) to +# visualize where the user has clicked + self.updateplot(newx, newy) + + + def DefineColormap(self, Ncolors): + """Define a new colormap by assigning 10 values of the jet colormap + such that there are only colors for the values 0-10 and the values >10 + will be treated with a modulo operation (updatedata function) + """ + jet = cm.get_cmap('jet', Ncolors) + colors = [] + for i in range(0,Ncolors): + if i==0 : +# set background transparency to 0 + temp = list(jet(i)) + temp[3]= 0.0 + colors.append(tuple(temp)) + else: + colors.append(jet(i)) + colormap = ListedColormap(colors) + return colormap + + def plot(self, picture, mask, ax): + """this function is called for the first time when all the subplots + are drawn. + """ + +# Define a new colormap with 20 colors. + newcmp = self.DefineColormap(21) + ax.axis("off") + + self.draw() + return ax.imshow(picture, interpolation= 'None', origin = 'upper', cmap = 'gray_r'), ax.imshow((mask%10+1)*(mask != 0), origin = 'upper', interpolation = 'None', alpha = 0.2, cmap = newcmp) + + + def UpdateBckgrndPicture(self): + """this function can be called to redraw all the pictures and the mask, + so it is called whenever a time index is entered by the user and + the corresponding pictures and masks are updated. And then they are + drawn here. + When the user changes time frame using the next or previous time frame + buttons, it is also this function which is called. + When the user changes the field of view, it is also + this function which finally draws all the plots. + + """ + + self.currplot.set_data(self.currpicture) + self.currplot.set_clim(np.amin(self.currpicture), np.amax(self.currpicture)) + self.currmask.set_data((self.plotmask%10+1)*(self.plotmask!=0)) + self.ax.draw_artist(self.currplot) + self.ax.draw_artist(self.currmask) + + + self.previousplot.set_data(self.prevpicture) + self.previousplot.set_clim(np.amin(self.prevpicture), np.amax(self.prevpicture)) + self.previousmask.set_data((self.prevplotmask%10+1)*(self.prevplotmask != 0)) + + + self.ax2.draw_artist(self.previousplot) + self.ax2.draw_artist(self.previousmask) + + self.nextplot.set_data(self.nextpicture) + self.nextplot.set_clim(np.amin(self.nextpicture), np.amax(self.nextpicture)) + self.nextmask.set_data((self.nextplotmask % 10 +1 )*(self.nextplotmask != 0)) + self.ax3.draw_artist(self.nextplot) + self.ax3.draw_artist(self.nextmask) + + + self.update() + self.flush_events() + + + def updatedata(self, flag = True): + + + """ + In order to just display the cells so regions with value > 0 + and also to assign to each of the cell values one color, + the modulo 10 of the value is take and we add 1, to distinguish + the values of 10,20,30,... from the background (although the bckgrnd + gets with the addition the value 1) and the result of the + modulo is multiplied with a matrix containing a False value for the + background coordinates, setting the background to 0 again. + """ + if flag: + self.currmask.set_data((self.plotmask%10+1)*(self.plotmask!=0)) + else : + self.currmask.set_data((self.tempmask%10+1)*(self.tempmask!=0)) + + + +# show the updates by redrawing the array using draw_artist, it is faster +# to use as it only redraws the array itself, and not everything else. + + self.ax.draw_artist(self.currplot) + self.ax.draw_artist(self.currmask) +# self.ax.text(500,500,'test') + self.update() + self.flush_events() +# +# if self.button_showval_check.isChecked(): +# self.ShowCellNumbersCurr() +# self.ShowCellNumbersNext() +# self.ShowCellNumbersPrev() + + def Update3Plots(self): + """This function is just used to draw the update the masks on + the three subplots. It is only used by the Hidemask function. + To "show" the masks again when the button is unchecked. + """ + +# self.currmask.set_data((self.plotmask%10+1)*(self.plotmask!=0)) + self.ax.draw_artist(self.currplot) + self.ax.draw_artist(self.currmask) + + self.ax2.draw_artist(self.previousplot) + self.ax2.draw_artist(self.previousmask) + + self.ax3.draw_artist(self.nextplot) + self.ax3.draw_artist(self.nextmask) + + self.update() + self.flush_events() + + def HideMask(self): + + if self.button_hidemask_check.isChecked(): + self.ax.draw_artist(self.currplot) + self.ax2.draw_artist(self.previousplot) + self.ax3.draw_artist(self.nextplot) + self.update() + self.flush_events() + + else: + self.Update3Plots() + + + def ShowCellNumbersCurr(self): + """This function is called to display the cell values and it + takes 10 random points inside of the cell, computes the mean of these + points and this gives the coordinate where the number will be + displayed. The number to be displayed is just given by the value + in the mask of the cell. + This function is just used for the current time subplot. + """ + + + for i,a in enumerate(self.ann_list): + a.remove() + self.ann_list[:] = [] + + + if self.button_showval_check.isChecked(): + vals = np.unique(self.plotmask) + for k in vals: + x,y = (self.plotmask==k).nonzero() + sample = np.random.choice(len(x), size=20, replace=False) + meanx = np.mean(x[sample]) + meany = np.mean(y[sample]) + xtemp.append(int(round(meanx))) + ytemp.append(int(round(meany))) + val.append(k) + +## if self.button_showval_check.isChecked(): +## +## maxval = np.amax(self.plotmask) +## minval = 1 +## xtemp = [] +## ytemp = [] +## val =[] +## for k in range(minval,maxval + 1): +## if k in self.plotmask: +## indices = np.where(self.plotmask == k) +## sampley = random.choices(list(indices[0]), k = 10) +## samplex = random.choices(list(indices[1]), k = 10) +## meanx = np.mean(samplex) +## meany = np.mean(sampley) +## xtemp.append(int(round(meanx))) +## ytemp.append(int(round(meany))) +## val.append(k) + if xtemp: + for i in range(0,len(xtemp)): + ann = self.ax.annotate(str(val[i]), (xtemp[i], ytemp[i])) + self.ann_list.append(ann) + + + self.draw() +# val, ct = np.unique(self.plotmask, return_counts = True) +# print(val) +# print(ct) + + else: + + for i,a in enumerate(self.ann_list): + a.remove() + self.ann_list[:] = [] + +# self.txt.remove() + self.updatedata() + + def ShowCellNumbersPrev(self): + """This function is called to display the cell values and it + takes 10 random points inside of the cell, computes the mean of these + points and this gives the coordinate where the number will be + displayed. The number to be displayed is just given by the value + in the mask of the cell. + This function is just used for the previous time subplot. + """ + + for i,a in enumerate(self.ann_list_prev): + a.remove() + self.ann_list_prev[:] = [] + + + if self.button_showval_check.isChecked(): + + maxval = np.amax(self.prevplotmask) + minval = 1 + xtemp = [] + ytemp = [] + val = [] + + for k in range(minval,maxval + 1): + if k in self.prevplotmask: + indices = np.where(self.prevplotmask == k) + sampley = random.choices(list(indices[0]), k = 10) + samplex = random.choices(list(indices[1]), k = 10) + meanx = np.mean(samplex) + meany = np.mean(sampley) + xtemp.append(int(round(meanx))) + ytemp.append(int(round(meany))) + val.append(k) + if xtemp: + for i in range(0,len(xtemp)): + ann = self.ax2.annotate(str(val[i]), (xtemp[i], ytemp[i])) + self.ann_list_prev.append(ann) + + self.draw() + + else: + + for i,a in enumerate(self.ann_list_prev): + a.remove() + self.ann_list_prev[:] = [] + + + self.previousmask.set_data((self.prevplotmask%10+1)*(self.prevplotmask!=0)) + + self.ax2.draw_artist(self.previousplot) + self.ax2.draw_artist(self.previousmask) + self.update() + self.flush_events() + + + def ShowCellNumbersNext(self): + + """This function is called to display the cell values and it + takes 10 random points inside of the cell, computes the mean of these + points and this gives the coordinate where the number will be + displayed. The number to be displayed is just given by the value + in the mask of the cell. + This function is just used for the next time subplot. + """ + + for i,a in enumerate(self.ann_list_next): + a.remove() + self.ann_list_next[:] = [] + + + + + if self.button_showval_check.isChecked(): + + maxval = np.amax(self.nextplotmask) + minval = 1 + xtemp = [] + ytemp = [] + val = [] + for k in range(minval,maxval + 1): + if k in self.nextplotmask: + indices = np.where(self.nextplotmask == k) + sampley = random.choices(list(indices[0]), k = 10) + samplex = random.choices(list(indices[1]), k = 10) + meanx = np.mean(samplex) + meany = np.mean(sampley) + xtemp.append(int(round(meanx))) + ytemp.append(int(round(meany))) + val.append(k) + if xtemp: + for i in range(0,len(xtemp)): + ann = self.ax3.annotate(str(val[i]), (xtemp[i], ytemp[i])) + self.ann_list_next.append(ann) + + self.draw() + + else: + + for i,a in enumerate(self.ann_list_next): + a.remove() + self.ann_list_next[:] = [] + + + self.nextmask.set_data((self.nextplotmask%10+1)*(self.nextplotmask!=0)) + + self.ax3.draw_artist(self.nextplot) + self.ax3.draw_artist(self.nextmask) + self.update() + self.flush_events() + + + + def updateplot(self, posx, posy): + + + + +# it updates the plot once the user clicks on the plot and draws a 4x4 pixel dot +# at the coordinate of the click +# self.modulomask = self.plotmask.copy() + xtemp, ytemp = self.storemouseclicks[0] +# remove the first coordinate as it should only coorespond +# to the value that the user wants to attribute to the drawn region + +# here we initialize the value attributed to the pixels. +# it means that the first click selects the value that will be attributed to +# the pixels inside the polygon (drawn by the following mouse clicks of the user) + + self.cellval = self.plotmask[ytemp, xtemp] + + + + +# drawing the 2x2 square ot of the mouse click + + + if (self.button_newcell_check.isChecked() or self.button_drawmouse_check.isChecked()) and self.cellval == 0: + self.tempmask[posy:posy+2, posx:posx+2] = 9 + else: + self.tempmask[posy:posy+2,posx:posx+2] = self.cellval + +# plot the mouseclick + + self.updatedata(False) + + + + def DrawRegion(self, flag): + """ + this method is used to draw either a new cell (flag = true) or to add a region to + an existing cell (flag = false). The flag will just be used to set the + value of pixels (= self.cellval) in the drawn region. + If flag = true, then the value will be the maximal value plus 1. Such + that it attributes a new value to the new cell. + If flag = false, then it will use the value of the first click to set + the value of the pixels in the new added region. + """ + + + +# here the values that have been changed to mark the mouse clicks are +# restored such that they don't appear when the region/new cell is +# drawn. + + + + if flag: +# if new cell is added, it sets the value of the drawn pixels to a new value +# corresponding to the new cell + self.cellval = np.amax(self.plotmask) + 1 + else: +# The first value is taken out as it is just used to set the value +# to the new region. +# self.store_values.pop(0) + self.storemouseclicks.pop(0) + + + if len(self.storemouseclicks) <= 2: +# if only two points or less have been click, it cannot make a area +# so it justs discards these values and returns. + self.storemouseclicks = list(self.storemouseclicks) + + self.storemouseclicks.clear() + + self.updatedata(True) +# self.store_values.clear() + return + + else: +# add the first point because to use the path function, one has to close +# the path by returning to the initial point. + self.storemouseclicks.append(self.storemouseclicks[0]) + +# codes are requested by the path function in order to make a polygon +# out of the points that have been selected. + + codes = np.zeros(len(self.storemouseclicks)) + codes[0] = Path.MOVETO + codes[len(codes)-1]= Path.CLOSEPOLY + codes[1:len(codes)-1] = Path.LINETO + codes = list(codes) + +# out of the coordinates of the mouse clicks and of the code, it makes +# a path/contour which corresponds to the added region/new cell. + path = Path(self.storemouseclicks, codes) + + + + self.storemouseclicks = np.array(self.storemouseclicks) + +# Take a square around the drawn region, where the drawn region fits inside. + minx = min(self.storemouseclicks[:,0]) + maxx = max(self.storemouseclicks[:,0]) + + miny = min(self.storemouseclicks[:,1]) + maxy= max(self.storemouseclicks[:,1]) + +# creates arrays of coordinates of the whole square surrounding +# the drawn region + array_x = np.arange(minx, maxx, 1) + array_y = np.arange(miny, maxy, 1) +# + array_coord = [] +# takes all the coordinates to couple them and store them in array_coord + for xi in range(0,len(array_x)): + for yi in range(0,len(array_y)): + + array_coord.append((array_x[xi], array_y[yi])) + +# path_contains_points returns an array of bool values +# where for each coordinates it tests if it is inside the path + + pix_inside_path = path.contains_points(array_coord) + +# for each coordinate where the contains_points method returned true +# the value of the self.currpicture matrix is changed, it draws the region +# defined by the user + + for j in range(0,len(pix_inside_path)): + if pix_inside_path[j]: + x,y = array_coord[j] + self.plotmask[y,x]= self.cellval + + +# once the self.currpicture has been updated it is drawn by callinf the +# updatedata method. + + self.updatedata() + + + self.storemouseclicks = list(self.storemouseclicks) + +# empty the lists ready for the next region to be drawn. + self.storemouseclicks.clear() +# self.store_values.clear() + + + +if __name__ == '__main__': + app = QApplication(sys.argv) + wind = dfb.FileBrowser() + if wind.exec_(): + nd2name1 = wind.nd2name + hdfname1 = wind.hdfname + hdfnewname = wind.newhdfentry.text() + ex = App(nd2name1, hdfname1, hdfnewname) + sys.exit(app.exec_()) + else: + app.exit() + diff --git a/disk/DialogDataBrowser.py b/disk/DialogDataBrowser.py new file mode 100644 index 0000000000000000000000000000000000000000..fe6c8fe6b6d1a70ed2a950205ab55459eb6f14e5 --- /dev/null +++ b/disk/DialogDataBrowser.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- +""" +Created on Tue Nov 19 17:38:58 2019 +""" + +from PyQt5.QtWidgets import QApplication, QMainWindow, QMenu, QVBoxLayout, QSizePolicy, QMessageBox, QWidget, QPushButton, QShortcut, QComboBox, QDialog, QDialogButtonBox, QInputDialog, QLineEdit, QFormLayout, QFileDialog, QLabel +from PyQt5 import QtGui +#from PyQt5.QtGui import QIcon, QKeySequence +from PyQt5.QtCore import pyqtSignal, QObject, Qt +#import PyQt package, allows for GUI interactions + +class FileBrowser(QDialog): + + def __init__(self, *args, **kwargs): + super(FileBrowser, self).__init__(*args, **kwargs) + + self.setWindowTitle("Data file") + self.setGeometry(100,100, 800,200) + + + self.button_openxls = QPushButton('Open excel file') + self.button_openxls.setEnabled(True) + self.button_openxls.clicked.connect(self.getxlspath) + self.button_openxls.setToolTip("Browse for an xls file") + self.button_openxls.setMaximumWidth(150) + + + + self.newxlsentry = QLineEdit() + + self.xlsname = '' +# + flo = QFormLayout() +# flo.addRow('Enter Cell value 1 (integer):', self.entry1) + + + QBtn = QDialogButtonBox.Ok | QDialogButtonBox.Cancel + + self.buttonBox = QDialogButtonBox(QBtn) + self.buttonBox.accepted.connect(self.accept) + self.buttonBox.rejected.connect(self.reject) + + + self.labelxls = QLabel() + self.labelxls.setText('No xls file selected') + + flo.addRow(self.labelxls, self.button_openxls) + +# flo.addWidget(self.button_openhdf) + flo.addRow('If no xls data file already exists, give a name to create a new file', self.newxlsentry) + + flo.addWidget(self.buttonBox) + + self.setLayout(flo) + + + + + def getxlspath(self): + self.xlsname,_ = QFileDialog.getOpenFileName(self, 'Open .xls File','', 'xls Files (*.xls)') +# print(self.nd2name) +# print(self.nd2name) + self.labelxls.setText(self.xlsname) + + \ No newline at end of file diff --git a/disk/DialogFileBrowser.py b/disk/DialogFileBrowser.py new file mode 100644 index 0000000000000000000000000000000000000000..2cb50bb8e6e1ba354a7ca13eaee7a72c5cdc97f0 --- /dev/null +++ b/disk/DialogFileBrowser.py @@ -0,0 +1,92 @@ +# -*- coding: utf-8 -*- +""" +Created on Tue Nov 19 17:38:58 2019 +""" + +from PyQt5.QtWidgets import QApplication, QMainWindow, QMenu, QVBoxLayout, QSizePolicy, QMessageBox, QWidget, QPushButton, QShortcut, QComboBox, QDialog, QDialogButtonBox, QInputDialog, QLineEdit, QFormLayout, QFileDialog, QLabel +from PyQt5 import QtGui +#from PyQt5.QtGui import QIcon, QKeySequence +from PyQt5.QtCore import pyqtSignal, QObject, Qt +#import PyQt package, allows for GUI interactions + +class FileBrowser(QDialog): + + def __init__(self, *args, **kwargs): + super(FileBrowser, self).__init__(*args, **kwargs) + + self.setWindowTitle("Open Files") + self.setGeometry(100,100, 800,200) + + + self.button_opennd2 = QPushButton('Open Image File') + self.button_opennd2.setEnabled(True) + self.button_opennd2.clicked.connect(self.getnd2path) + self.button_opennd2.setToolTip("Browse for an image file") + self.button_opennd2.setMaximumWidth(150) + + self.button_openfolder = QPushButton('Open Image Folder') + self.button_openfolder.setEnabled(True) + self.button_openfolder.clicked.connect(self.getfolder) + self.button_openfolder.setToolTip("Browse for folder with images") + self.button_openfolder.setMaximumWidth(150) + + self.button_openhdf = QPushButton('Open hdf file') + self.button_openhdf.setEnabled(True) + self.button_openhdf.clicked.connect(self.gethdfpath) + self.button_openhdf.setToolTip("Browse for an hdf file containing the masks") + self.button_openhdf.setMaximumWidth(150) + + self.newhdfentry = QLineEdit() +# self.newhdfentry(Qt.AlignLeft) + + + + self.nd2name = '' + self.hdfname = '' +# + flo = QFormLayout() +# flo.addRow('Enter Cell value 1 (integer):', self.entry1) + + + QBtn = QDialogButtonBox.Ok | QDialogButtonBox.Cancel + + self.buttonBox = QDialogButtonBox(QBtn) + self.buttonBox.accepted.connect(self.accept) + self.buttonBox.rejected.connect(self.reject) + + + self.labelnd2 = QLabel() + self.labelnd2.setText('No nd2 file selected') + + self.labelhdf = QLabel() + self.labelhdf.setText('No hdf file selected') + + self.labelfolder = QLabel() + self.labelfolder.setText('No folder selected') + + flo.addRow(self.labelnd2, self.button_opennd2) + flo.addRow(self.labelfolder, self.button_openfolder) + flo.addRow(self.labelhdf, self.button_openhdf) +# flo.addWidget(self.button_openhdf) + flo.addRow('If no hdf file already exists, give a name to create a new file', self.newhdfentry) + + flo.addWidget(self.buttonBox) + + self.setLayout(flo) + + + + + def getnd2path(self): + self.nd2name,_ = QFileDialog.getOpenFileName(self, 'Open .nd2 File','', 'Image Files (*.nd2 *.tif *.tiff)') +# print(self.nd2name) +# print(self.nd2name) + self.labelnd2.setText(self.nd2name) + + def gethdfpath(self): + self.hdfname,_ = QFileDialog.getOpenFileName(self,'Open .hdf File','', 'hdf Files (*.h5)') + self.labelhdf.setText(self.hdfname) + + def getfolder(self): + self.nd2name = QFileDialog.getExistingDirectory(self, ("Select Image Folder")) + self.labelnd2.setText(self.nd2name) \ No newline at end of file diff --git a/disk/InteractionDisk_temp.py b/disk/InteractionDisk_temp.py new file mode 100644 index 0000000000000000000000000000000000000000..a52672b078190ae1499255f8c516642bd4088f19 --- /dev/null +++ b/disk/InteractionDisk_temp.py @@ -0,0 +1,525 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Tue Oct 15 15:00:29 2019 + +This program reads out the images from the nd2 file and creates or +reads the hdf file containing the segmentation. +""" +from nd2reader import ND2Reader +#import matplotlib.pyplot as plt +import numpy as np + +import h5py +import os.path +import skimage.io +#import segment as seg +import neural_network as nn +import pytiff + + +import CellCorrespondance as cc +# import matplotlib.pyplot as plt + + + +class Reader: + + def __init__(self, hdfpathname, newhdfname, nd2pathname): + + +# Initializes the data corresponding to the sizes of the pictures, +# the number of different fields of views(Npos) taken in the experiment. +# And it also sets the number of time frames per field of view. + + # Identify filetype of image file + _, self.extension = os.path.splitext(nd2pathname) + self.isnd2 = self.extension == '.nd2' + self.istiff = self.extension == '.tif' or self.extension == '.tiff' + self.isfolder = self.extension == '' + + + self.nd2path = nd2pathname # path name is nd2path for legacy reasons + self.hdfpath = hdfpathname + self.newhdfname = newhdfname + + if self.isnd2: + with ND2Reader(self.nd2path) as images: + self.sizex = images.sizes['x'] + self.sizey = images.sizes['y'] + self.sizec = images.sizes['c'] + self.sizet = images.sizes['t'] + self.Npos = images.sizes['v'] + self.channel_names = images.metadata['channels'] + + elif self.istiff: + with pytiff.Tiff(self.nd2path) as handle: + self.sizex, self.sizey = handle.shape + self.sizec = 1 + self.sizet = handle.number_of_pages + self.Npos = 1 + self.channel_names = ['Channel1'] + + elif self.isfolder: + filelist = os.listdir(self.nd2path) + im = skimage.io.imread(self.nd2path + '/' + filelist[0]) + self.sizex, self.sizey = im.shape + self.sizec = 1 + self.Npos = 1 + self.sizet = len(filelist) + print(self.sizet) + self.channel_names = ['Channel1'] + + #create the labels which index the masks with respect to time and + #fov indices in the hdf5 file + self.fovlabels = [] + self.tlabels = [] + self.InitLabels() + + self.default_channel = 0 + + self.name = self.hdfpath + + self.predictname = '' + self.thresholdname = '' + self.segmentname = '' + +# self.channelwindow = chch.CustomDialog(self) +# +# if self.channelwindow.exec_(): +# +# self.default_channel = self.channelwindow.button_channel.currentIndex() + + +# create an new hfd5 file if no one existing already + self.Inithdf() + + + + def InitLabels(self): + """Create two lists containing all the possible fields of view and time + labels, in order to access the arrays in the hdf5 file. + """ + + for i in range(0, self.Npos): + self.fovlabels.append('FOV' + str(i)) + + for j in range(0, self.sizet): + self.tlabels.append('T'+ str(j)) + + + + def Inithdf(self): + """If the file already exists then it is loaded else + a new hdf5 file is created and for every fields of view + a new group is created in the createhdf method + """ + + if not self.hdfpath: + return self.Createhdf() + else: +# + temp = self.hdfpath[:-3] + + self.thresholdname = temp + '_thresholded' + '.h5' + self.segmentname = temp + '_segmented' + '.h5' + self.predictname = temp + '_predicted' + '.h5' + +# + + def Createhdf(self): + + """In this method, for each field of view one group is created. And + in each one of these group, there will be for each time frame a + corresponding dataset equivalent to a 2d array containing the + corresponding masks data (segmented/thresholded/predicted). + """ +# print('createhdf') + + self.hdfpath = '' + templist = self.nd2path.split('/') + for k in range(0, len(templist)-1): + self.hdfpath = self.hdfpath+templist[k]+'/' + + self.hdfpath = self.hdfpath + self.newhdfname + '.h5' + + hf = h5py.File(self.hdfpath, 'w') + + for i in range(0, self.Npos): + + grpname = self.fovlabels[i] + hf.create_group(grpname) + + hf.close() + + + + + for k in range(0, len(templist)-1): + self.thresholdname = self.thresholdname+templist[k]+'/' + self.thresholdname = self.thresholdname + self.newhdfname + '_thresholded' + '.h5' + + hf = h5py.File(self.thresholdname,'w') + + for i in range(0, self.Npos): + + grpname = self.fovlabels[i] + hf.create_group(grpname) + + hf.close() + + for k in range(0, len(templist)-1): + self.segmentname = self.segmentname+templist[k]+'/' + self.segmentname = self.segmentname + self.newhdfname + '_segmented' + '.h5' + + hf = h5py.File(self.segmentname,'w') + + for i in range(0, self.Npos): + + grpname = self.fovlabels[i] + hf.create_group(grpname) + + hf.close() + + for k in range(0, len(templist)-1): + self.predictname = self.predictname+templist[k]+'/' + self.predictname = self.predictname + self.newhdfname + '_predicted' + '.h5' + + hf = h5py.File(self.predictname,'w') + + for i in range(0, self.Npos): + + grpname = self.fovlabels[i] + hf.create_group(grpname) + + hf.close() + + def LoadMask(self, currentT, currentFOV): + """this method is called when one mask should be loaded from the file + on the disk to the user's buffer. If there is no mask corresponding + in the file, it creates the mask corresponding to the given time and + field of view index and returns an array filled with zeros. + """ + + file = h5py.File(self.hdfpath,'r+') + if self.TestTimeExist(currentT,currentFOV,file): + mask = np.array(file['/{}/{}'.format(self.fovlabels[currentFOV], self.tlabels[currentT])], dtype = np.uint16) + file.close() + + + return mask + + + else: + +# change with Matthias code! + + zeroarray = np.zeros([self.sizey, self.sizex],dtype = np.uint16) + file.create_dataset('/{}/{}'.format(self.fovlabels[currentFOV], self.tlabels[currentT]), data = zeroarray, compression = 'gzip') + file.close() + return zeroarray + + + def TestTimeExist(self,currentT, currentFOV, file): + """This method tests if the array which is requested by LoadMask + already exists or not in the hdf file. + """ + + for t in file['/{}'.format(self.fovlabels[currentFOV])].keys(): + if t == self.tlabels[currentT]: + return True + + return False + + + + def SaveMask(self, currentT, currentFOV, mask): + """This function is called when the user wants to save the mask in the + hdf5 file on the disk. It overwrites the existing array with the new + one given in argument. + If it is a new mask, there should already + be an existing null array which has been created by the LoadMask method + when the new array has been loaded/created in the main before calling + this save method. + """ + + file = h5py.File(self.hdfpath, 'r+') + + if self.TestTimeExist(currentT,currentFOV,file): + dataset= file['/{}/{}'.format(self.fovlabels[currentFOV], self.tlabels[currentT])] + dataset[:] = mask + file.close() + + else: + + file.create_dataset('/{}/{}'.format(self.fovlabels[currentFOV], self.tlabels[currentT]), data = mask, compression = 'gzip') + file.close() + + + def SaveThresholdMask(self, currentT, currentFOV, mask): + """This function is called when the user wants to save the mask in the + hdf5 file on the disk. It overwrites the existing array with the new + one given in argument. + If it is a new mask, there should already + be an existing null array which has been created by the LoadMask method + when the new array has been loaded/created in the main before calling + this save method. + """ + + file = h5py.File(self.thresholdname, 'r+') + + if self.TestTimeExist(currentT,currentFOV,file): + dataset = file['/{}/{}'.format(self.fovlabels[currentFOV], self.tlabels[currentT])] + dataset[:] = mask + file.close() + else: + file.create_dataset('/{}/{}'.format(self.fovlabels[currentFOV], self.tlabels[currentT]), data = mask, compression = 'gzip') + file.close() + + def SaveSegMask(self, currentT, currentFOV, mask): + """This function is called when the user wants to save the mask in the + hdf5 file on the disk. It overwrites the existing array with the new + one given in argument. + If it is a new mask, there should already + be an existing null array which has been created by the LoadMask method + when the new array has been loaded/created in the main before calling + this save method. + """ + + file = h5py.File(self.segmentname, 'r+') + + if self.TestTimeExist(currentT,currentFOV,file): + + dataset = file['/{}/{}'.format(self.fovlabels[currentFOV], self.tlabels[currentT])] + dataset[:] = mask + file.close() + else: + file.create_dataset('/{}/{}'.format(self.fovlabels[currentFOV], self.tlabels[currentT]), data = mask, compression = 'gzip') + file.close() + + + + def TestIndexRange(self,currentT, currentfov): + """this method receives the time and the fov index and checks + if it is present in the images data. + """ + + if currentT < (self.sizet-1) and currentfov < self.Npos: + return True + if currentT == self.sizet - 1 and currentfov < self.Npos: + return False +# + def LoadOneImage(self,currentT, currentfov): + """This method returns from the nd2 file, the picture requested by the + main program as an array. It fixes the fov index and iterates over the + time index. + """ + + if not (currentT < self.sizet and currentfov < self.Npos): + return None + + if self.isnd2: + with ND2Reader(self.nd2path) as images: + images.default_coords['v'] = currentfov + images.default_coords['c'] = self.default_channel + images.iter_axes = 't' + im = images[currentT] + + elif self.istiff: + with pytiff.Tiff(self.nd2path) as handle: + handle.set_page(currentT) + im = handle[:] + + elif self.isfolder: + filelist = os.listdir(self.nd2path) + im = skimage.io.imread(self.nd2path + '/' + filelist[currentT]) + + return np.array(im, dtype = np.uint16) + + + def LoadSeg(self, currentT, currentFOV): + + + file = h5py.File(self.segmentname, 'r+') + + if self.TestTimeExist(currentT,currentFOV,file): + mask = np.array(file['/{}/{}'.format(self.fovlabels[currentFOV], self.tlabels[currentT])], dtype = np.uint16) + file.close() + + return mask + + + else: + + zeroarray = np.zeros([self.sizey, self.sizex],dtype = np.uint16) + file.create_dataset('/{}/{}'.format(self.fovlabels[currentFOV], self.tlabels[currentT]), data = zeroarray, compression = 'gzip', compression_opts = 7) + + file.close() + return zeroarray + + + + def LoadThreshold(self, currentT, currentFOV): + + + file = h5py.File(self.thresholdname, 'r+') + + if self.TestTimeExist(currentT,currentFOV,file): + + mask = np.array(file['/{}/{}'.format(self.fovlabels[currentFOV], self.tlabels[currentT])], dtype = np.uint16) + file.close() + + return mask + + + else: + + zeroarray = np.zeros([self.sizey, self.sizex],dtype = np.uint16) + file.create_dataset('/{}/{}'.format(self.fovlabels[currentFOV], self.tlabels[currentT]), data = zeroarray, compression = 'gzip', compression_opts = 7) + + file.close() + return zeroarray + + def Segment(self, segparamvalue, currentT, currentFOV): + print(segparamvalue) + +# Check if thresholded version exists + filethr = h5py.File(self.thresholdname, 'r+') + fileprediction = h5py.File(self.predictname,'r+') # SJR: added to read out the prediction as well + +# if self.TestTimeExist(currentT, currentFOV, filethr): + if self.TestTimeExist(currentT, currentFOV, filethr) and self.TestTimeExist(currentT, currentFOV, fileprediction): # SJR: added to read out the prediction as well + + tmpthrmask = np.array(filethr['/{}/{}'.format(self.fovlabels[currentFOV], self.tlabels[currentT])]) + pred = np.array(fileprediction['/{}/{}'.format(self.fovlabels[currentFOV], self.tlabels[currentT])]) # SJR: added to read out the prediction as well + fileprediction.close() # SJR: added to read out the prediction as well + +# segmentedmask = nn.segment(tmpthrmask, segparamvalue) + segmentedmask = nn.segment(tmpthrmask, pred, segparamvalue) # SJR: added to read out the prediction as well + filethr.close() + + return segmentedmask + + else: + + filethr.close() + return np.zeros([self.sizey,self.sizex], dtype = np.uint16) + + + + def ThresholdPred(self, thvalue, currentT, currentFOV): + print(thvalue) + + fileprediction = h5py.File(self.predictname,'r+') + if self.TestTimeExist(currentT, currentFOV, fileprediction): + + pred = np.array(fileprediction['/{}/{}'.format(self.fovlabels[currentFOV], self.tlabels[currentT])]) + fileprediction.close() + if thvalue == None: + thresholdedmask = nn.threshold(pred) + else: + thresholdedmask = nn.threshold(pred,thvalue) + + return thresholdedmask + else: + fileprediction.close() + return np.zeros([self.sizey, self.sizex], dtype = np.uint16) + +# def LaunchPrediction(self, currentT, currentFOV): + + + def TestPredExisting(self, currentT, currentFOV): + + file = h5py.File(self.predictname, 'r+') + if self.TestTimeExist(currentT, currentFOV, file): + file.close() + return True + else: + file.close() + return False + + + + + + def LaunchPrediction(self, currentT, currentFOV): + + """It launches the neural neutwork on the current image and creates + an hdf file with the prediction for the time T and corresponding FOV. + """ + + file = h5py.File(self.predictname, 'r+') + + + im = self.LoadOneImage(currentT, currentFOV) + pred = nn.prediction(im) + file.create_dataset('/{}/{}'.format(self.fovlabels[currentFOV], + self.tlabels[currentT]), data = pred, compression = 'gzip', + compression_opts = 7) + file.close() + +# if self.isnd2: +# with ND2Reader(self.nd2path) as images: +# images.default_coords['v'] = currentFOV +# images.default_coords['c'] = self.default_channel +# images.iter_axes = 't' +# temp = images[currentT] +# temp = np.array(temp, dtype = np.uint16) +# pred = nn.prediction(temp) +# file.create_dataset('/{}/{}'.format(self.fovlabels[currentFOV], +# self.tlabels[currentT]), data = pred, compression = 'gzip', +# compression_opts = 7) +# +# elif self.istiff: +# None + + + + + + def CellCorrespondance(self, currentT, currentFOV): + print('in cell Correspondance') + filemasks = h5py.File(self.hdfpath, 'r+') + fileseg = h5py.File(self.segmentname,'r+') + if self.TestTimeExist(currentT-1, currentFOV, filemasks): + + if self.TestTimeExist(currentT, currentFOV, fileseg): + print('inside cellcorerspoindacefunction') + prevmask = np.array(filemasks['/{}/{}'.format(self.fovlabels[currentFOV], self.tlabels[currentT-1])]) + nextmask = np.array(fileseg['/{}/{}'.format(self.fovlabels[currentFOV], self.tlabels[currentT])]) + newmask, notifymask = cc.CellCorrespondancePlusTheReturn(nextmask, prevmask) + filemasks.close() + fileseg.close() + return newmask, notifymask + + else: + filemasks.close() + fileseg.close() + null = np.zeros([self.sizey, self.sizex]) + + return null, null + else: + + filemasks.close() + fileseg.close() + null = np.zeros([self.sizey, self.sizex]) + return null, null + + def LoadImageChannel(self,currentT, currentFOV, ch): + if self.isnd2: + with ND2Reader(self.nd2path) as images: + images.default_coords['v'] = currentFOV + images.default_coords['t'] = currentT + images.iter_axes = 'c' + im = images[ch] + return np.array(im) + + elif self.istiff: + return self.LoadOneImage(currentT, currentFOV) + + elif self.isfolder: + return self.LoadOneImage(currentT, currentFOV) + + + + diff --git a/icons/HomeIcon.png b/icons/HomeIcon.png new file mode 100644 index 0000000000000000000000000000000000000000..bf810fa88a99087ab2cb0519dfd0e8d5002f416c Binary files /dev/null and b/icons/HomeIcon.png differ diff --git a/icons/LeftArrowIcon.png b/icons/LeftArrowIcon.png new file mode 100644 index 0000000000000000000000000000000000000000..397a1861bba28bc941ea17d7c1f387c72c050180 Binary files /dev/null and b/icons/LeftArrowIcon.png differ diff --git a/icons/MoveArrowsIcon.png b/icons/MoveArrowsIcon.png new file mode 100644 index 0000000000000000000000000000000000000000..e106afa8cc6ca64e84c2de089ee1f105b5368256 Binary files /dev/null and b/icons/MoveArrowsIcon.png differ diff --git a/icons/RightArrowIcon.png b/icons/RightArrowIcon.png new file mode 100644 index 0000000000000000000000000000000000000000..e8f6f92a9b324a944ebe39a12023b9a783bace42 Binary files /dev/null and b/icons/RightArrowIcon.png differ diff --git a/icons/ZoomIcon.png b/icons/ZoomIcon.png new file mode 100644 index 0000000000000000000000000000000000000000..6d42e82582e0549b3eee2f109724bee8c785a754 Binary files /dev/null and b/icons/ZoomIcon.png differ diff --git a/icons/brush2.png b/icons/brush2.png new file mode 100644 index 0000000000000000000000000000000000000000..4821b3cd98809d7ed75695bc4729b8925cfa0ed7 Binary files /dev/null and b/icons/brush2.png differ diff --git a/icons/eraser.png b/icons/eraser.png new file mode 100644 index 0000000000000000000000000000000000000000..5b83d116867ddcd1c8b51b6fd9f608190f8eb234 Binary files /dev/null and b/icons/eraser.png differ diff --git a/init/InitButtons.py b/init/InitButtons.py new file mode 100644 index 0000000000000000000000000000000000000000..538ff22745ca4af58909be80fcdcc110b76d9b5c --- /dev/null +++ b/init/InitButtons.py @@ -0,0 +1,297 @@ +# -*- coding: utf-8 -*- +""" +Initializing all the buttons in this file. +""" +from PyQt5.QtWidgets import QWidget, QPushButton, QShortcut, QComboBox, QCheckBox, QLineEdit, QStatusBar +from PyQt5 import QtGui +from PyQt5.QtCore import pyqtSignal, QObject, Qt +def Init(parent): + + +# configuration of all the buttons, some buttons just need to be toggled +# meaning that they need just to be clicked and the function associated +# to the button is called and executed. Other buttons are checkable +# meaning that until the user has not finished to use the function +# connected to the button, this button stays active (or Checked) + + + + # ADD REGION + parent.button_add_region.setCheckable(True) + parent.button_add_region.setMaximumWidth(150) + parent.button_add_region.clicked.connect(parent.clickmethod) +# the connect function calls clickmethod when the button is clicked. + parent.button_add_region.setShortcut("R") + parent.button_add_region.setToolTip("Use R Key for shortcut") + parent.button_add_region.setStatusTip('The first left click sets the value for the next clicks, then draw polygons') + + +# NEW CELL + parent.button_newcell.setCheckable(True) + parent.button_newcell.setMaximumWidth(150) + parent.button_newcell.setShortcut("N") + parent.button_newcell.setToolTip("Use N Key for shortcut") +# when the button is pushed the ClickNewCell method is called. + parent.button_newcell.clicked.connect(parent.ClickNewCell) + parent.button_newcell.setStatusTip('Use the left click to produce a polygon with a new cell value') + +# NEXT FRAME (TIME AXIS) +# this button is used to navigate in the time axis, to show the t+1 frame +# parent.button_nextframe.setCheckable(True) + parent.button_nextframe.toggle() + parent.button_nextframe.pressed.connect(parent.ChangeNextFrame) +# parent.button_nextframe.released.connect(parent.test2) + parent.button_nextframe.setToolTip("Use right arrow key for shortcut") + parent.button_nextframe.setMaximumWidth(150) + parent.button_nextframe.setShortcut(Qt.Key_Right) +# parent.buttonlist.append(parent.button_nextframe) +# parent.button_nextframe.setShortcutAutoRepeat(False) + + +# PREVIOUS FRAME (TIME AXIS) + +# this button is used to load the previous frame in the time axis +# parent.button_previousframe.setCheckable(True) + parent.button_previousframe.setEnabled(False) + parent.button_previousframe.toggle() + parent.button_previousframe.pressed.connect(parent.ChangePreviousFrame) + parent.button_previousframe.setToolTip("Use left arrow key for shortcut") + parent.button_previousframe.setMaximumWidth(150) + parent.button_previousframe.setShortcut(Qt.Key_Left) + parent.button_previousframe.move(100,100) + + + +# Initializing the buttons appearing in the navigation toolbar of the +# matplotlib library and connect them to the original functions already +# implemented in the navigation toolbar + +# ZOOM + parent.button_zoom = QPushButton() + parent.button_zoom.clicked.connect(parent.ZoomTlbar) + parent.button_zoom.setIcon(QtGui.QIcon('./icons/ZoomIcon.png')) + parent.button_zoom.setMaximumWidth(30) + parent.button_zoom.setMaximumHeight(30) + parent.button_zoom.setStyleSheet("border: 0px" "QPushButton:hover { background-color: blue }" ) + parent.button_zoom.setCheckable(True) + parent.button_zoom.setShortcut("Z") + parent.button_zoom.setToolTip("Use Z Key for shortcut") + parent.buttonlist.append(parent.button_zoom) + +# HOME + parent.button_home = QPushButton() + parent.button_home.clicked.connect(parent.HomeTlbar) + parent.button_home.setIcon(QtGui.QIcon('./icons/HomeIcon.png')) + parent.button_home.setMaximumWidth(30) + parent.button_home.setMaximumHeight(30) + parent.button_home.setStyleSheet("border: 0px" "QPushButton:hover { background-color: blue }" ) + parent.button_home.setShortcut("H") + parent.button_home.setToolTip("Use H Key for shortcut") + parent.buttonlist.append(parent.button_home) + +# PREVIOUS SCALE (ZOOM SCALE) + parent.button_back = QPushButton() + parent.button_back.clicked.connect(parent.BackTlbar) + parent.button_back.setIcon(QtGui.QIcon('./icons/LeftArrowIcon.png')) + parent.button_back.setMaximumWidth(30) + parent.button_back.setMaximumHeight(30) + parent.button_back.setStyleSheet("border: 0px" "QPushButton:hover { background-color: blue }" ) + parent.buttonlist.append(parent.button_back) + + +# NEXT SCALE (ZOOM SCALE) + parent.button_forward = QPushButton() + parent.button_forward.clicked.connect(parent.ForwardTlbar) + parent.button_forward.setIcon(QtGui.QIcon('./icons/RightArrowIcon.png')) + parent.button_forward.setMaximumWidth(30) + parent.button_forward.setMaximumHeight(30) + parent.button_forward.setStyleSheet("border: 0px" "QPushButton:hover { background-color: blue }" ) + parent.buttonlist.append(parent.button_forward) + +# PAN + parent.button_pan = QPushButton() + parent.button_pan.clicked.connect(parent.PanTlbar) + parent.button_pan.setIcon(QtGui.QIcon('./icons/MoveArrowsIcon.png')) + parent.button_pan.setMaximumWidth(30) + parent.button_pan.setMaximumHeight(30) + parent.button_pan.setStyleSheet("border: 0px" "QPushButton:hover { background-color: blue }" ) + parent.button_pan.setCheckable(True) + parent.button_pan.setShortcut("P") + parent.button_pan.setToolTip("Use P Key for shortcut") + parent.buttonlist.append(parent.button_pan) + + +# SAVE +# this button is used to save the current mask + parent.button_savemask.toggle() + parent.button_savemask.setEnabled(True) + parent.button_savemask.clicked.connect(parent.ButtonSaveMask) + parent.button_savemask.setToolTip("Use S Key for shortcut") + parent.button_savemask.setMaximumWidth(150) + parent.button_savemask.setShortcut("S") + parent.button_savemask.setStatusTip('Save the mask') + + +# BRUSH +# parent.button_drawmouse.setEnabled(True) + parent.button_drawmouse.setCheckable(True) + parent.button_drawmouse.clicked.connect(parent.MouseDraw) + parent.button_drawmouse.setToolTip("Use B Key for shortcut") + parent.button_drawmouse.setShortcut("B") + parent.button_drawmouse.setMaximumWidth(150) + parent.button_drawmouse.setStatusTip('Right click to select cell value and then left click and drag to draw') + + +# ERASER + parent.button_eraser.setCheckable(True) + parent.button_eraser.clicked.connect(parent.MouseDraw) + parent.button_eraser.setToolTip("Use E Key for shortcut") + parent.button_eraser.setShortcut("E") + parent.button_eraser.setMaximumWidth(150) + parent.button_eraser.setStatusTip('Right click and drag to set values to 0') + +# EXCHANGE CELL VALUES + parent.button_exval.toggle() + parent.button_exval.setEnabled(True) + parent.button_exval.clicked.connect(parent.DialogBoxECV) + parent.button_exval.setMaximumWidth(150) + parent.button_exval.setStatusTip('Exchange values between two cells') + +# CHANGE CELL VALUE + parent.button_changecellvalue.setCheckable(True) + parent.button_changecellvalue.setEnabled(True) + parent.button_changecellvalue.clicked.connect(parent.ChangeOneValue) + parent.button_changecellvalue.setMaximumWidth(150) + parent.button_changecellvalue.setStatusTip('Change value of one cell') + parent.button_changecellvalue.setToolTip('Use left click to select one cell and enter a new value') + + + +# THRESHOLD THE PREDICTION CHECKBOX + parent.button_threshold.setEnabled(False) + parent.button_threshold.stateChanged.connect(parent.ThresholdBoxCheck) + +# TEXT BOX FOR ENTERING THRESHOLD VALUE + parent.button_SetThreshold = QLineEdit() + parent.button_SetThreshold.setPlaceholderText('Enter a threshold value') + parent.button_SetThreshold.setValidator(QtGui.QDoubleValidator()) + parent.button_SetThreshold.setMaximumWidth(150) + parent.button_SetThreshold.returnPressed.connect(parent.ThresholdPrediction) + parent.button_SetThreshold.setEnabled(False) + + + +# SAVE BUTTON FOR THE THRESHOLDED PREDICTION + parent.button_savethresholdmask = QPushButton('Save Threshold') + parent.button_savethresholdmask.toggle() + parent.button_savethresholdmask.setEnabled(False) + parent.button_savethresholdmask.clicked.connect(parent.ButtonSaveThresholdMask) + parent.button_savethresholdmask.setMaximumWidth(150) + parent.button_savethresholdmask.setStatusTip('Save the thresholded prediction') + + +# SEGMENT THE OUTPUT OF THE THRESHOLD + parent.button_segment.setEnabled(False) + parent.button_segment.stateChanged.connect(parent.SegmentBoxCheck) + +# TEXT BOX FOR THE PARAMETERS OF THE SEGMENTATION + parent.button_SetSegmentation = QLineEdit() + parent.button_SetSegmentation.setPlaceholderText('Enter param for seg') + parent.button_SetSegmentation.setValidator(QtGui.QIntValidator()) + parent.button_SetSegmentation.setMaximumWidth(80) + parent.button_SetSegmentation.returnPressed.connect(parent.SegmentThresholdedPredMask) + parent.button_SetSegmentation.setEnabled(False) + parent.button_SetSegmentation.setText('10') + + +# SAVE BUTTON FOR THE SEGMENTATION OF THE THRESHOLDED PREDICTION + parent.button_savesegmask = QPushButton('Save Seg') + parent.button_savesegmask.toggle() + parent.button_savesegmask.setEnabled(False) + parent.button_savesegmask.clicked.connect(parent.ButtonSaveSegMask) + parent.button_savesegmask.setMaximumWidth(150) + parent.button_savesegmask.setStatusTip('Save the segmented thresholded prediction') + + + +# MAKE THE CELL CORRESPONDANCE + parent.button_cellcorespondance.setEnabled(False) + parent.button_cellcorespondance.setCheckable(True) + parent.button_cellcorespondance.clicked.connect(parent.CellCorrespActivation) + parent.button_cellcorespondance.setMaximumWidth(150) + parent.button_cellcorespondance.setStatusTip('Do the cell correspondance with the previous time frame') + + +# EXTRACT FLUORESCENCE IN DIFFERENT CHANNELS + + parent.button_extractfluorescence.setEnabled(False) + parent.button_extractfluorescence.toggle() + parent.button_extractfluorescence.clicked.connect(parent.ButtonFluo) + parent.button_extractfluorescence.setMaximumWidth(150) + parent.button_extractfluorescence.setStatusTip('Extract the total intensity, area and variance of the cells in the different channels') + + +# SHOW THE VALUES OF THE CELLS +# parent.button_showval.checkStateSet(False) + parent.button_showval.stateChanged.connect(parent.m.ShowCellNumbersCurr) + parent.button_showval.stateChanged.connect(parent.m.ShowCellNumbersPrev) + parent.button_showval.stateChanged.connect(parent.m.ShowCellNumbersNext) + parent.button_showval.setShortcut('V') + parent.button_showval.setToolTip("Use V Key for shortcut") + + +# HIDE/SHOW THE MASK + parent.button_hidemask.stateChanged.connect(parent.m.HideMask) + + +# CHANGE TIME INDEX + parent.button_timeindex = QLineEdit() + parent.button_timeindex.setPlaceholderText('Time index 0-{}'.format(parent.reader.sizet-1)) + parent.button_timeindex.setValidator(QtGui.QIntValidator(0,int(parent.reader.sizet-1))) + parent.button_timeindex.returnPressed.connect(parent.ChangeTimeFrame) + parent.button_timeindex.setMaximumWidth(150) + + parent.buttonlist.append(parent.button_timeindex) + + + +# FIELDS OF VIEW BUTTON +# Create a widget that displays the lists of different Fields of view. + parent.button_fov = QComboBox() +# create a list of different FOV which will be dispalyed in the widget. + list_fov = [] + for i in range(0, parent.reader.Npos): + list_fov.append("Field of View " + str(i+1)) + parent.button_fov.addItems(list_fov) + parent.button_fov.setMaximumWidth(150) + parent.buttonlist.append(parent.button_fov) +# connects the selection of one option in the ComboBox to the selectFOV +# function. + parent.button_fov.activated.connect(parent.SelectFov) + + + +# CHANGE CHANNEL BUTTON +# Create a widget that displays a list of the different channels. +# in order to switch between them + parent.button_channel = QComboBox() + parent.button_channel.addItems(parent.reader.channel_names) + parent.button_channel.setMaximumWidth(150) + parent.buttonlist.append(parent.button_channel) +# connects the selection of one option in the ComboBox to the selectFOV +# function. + parent.button_channel.activated.connect(parent.SelectChannel) + + + + + + +# NEURAL NETWORK BUTTON + parent.button_cnn.setCheckable(True) + parent.button_cnn.pressed.connect(parent.LaunchBatchPrediction) + parent.button_cnn.setToolTip("Launches the CNN on a range of images") + parent.button_cnn.setMaximumWidth(150) + parent.EnableCNNButtons() + + diff --git a/init/InitLayout.py b/init/InitLayout.py new file mode 100644 index 0000000000000000000000000000000000000000..15b53144ca776baf72c79460b7852d471dd7e49e --- /dev/null +++ b/init/InitLayout.py @@ -0,0 +1,114 @@ +# -*- coding: utf-8 -*- +""" +Initializing the layout of the main window. It places the buttons and the +pictures at the desired positions. +""" +#from PyQt5.QtWidgets import QApplication, QMainWindow, QMenu, QVBoxLayout, QSizePolicy, QMessageBox, QWidget, QPushButton, QShortcut, QComboBox, QCheckBox, QLineEdit, QMenu, QAction, QStatusBar +#from PyQt5 import QtGui +#from PyQt5.QtCore import pyqtSignal, QObject, Qt +from matplotlib.backends.qt_compat import QtCore, QtWidgets, is_pyqt5 + + +def Init(parent): + +# LAYOUT OF THE MAIN WINDOW + layout = QtWidgets.QVBoxLayout(parent._main) + + +# LAYOUT FOR THE THRESHOLD BUTTONS +# all the buttons of the threshold function are placed in an horizontal +# stack + hbox_threshold = QtWidgets.QHBoxLayout() + hbox_threshold.addWidget(parent.button_threshold) + hbox_threshold.addWidget(parent.button_SetThreshold) + hbox_threshold.addWidget(parent.button_savethresholdmask) +# if this line is not put, then the buttons are placed along the whole +# length of the window, in this way they are all grouped to the left. + hbox_threshold.addStretch(1) + + + +# LAYOUT FOR THE SEGMENTATION BUTTONS +# all the buttons of the segment function are placed in an +# horizontal stack + hbox_segment = QtWidgets.QHBoxLayout() + hbox_segment.addWidget(parent.button_segment) + hbox_segment.addWidget(parent.button_SetSegmentation) + hbox_segment.addWidget(parent.button_savesegmask) +# if this line is not put, then the buttons are placed along the whole +# length of the window, in this way they are all grouped to the left. + hbox_segment.addStretch(1) + + +# LAYOUT FOR THE TOOLBAR BUTTONS +# all the buttons of the toolbar are placed in an +# horizontal stack + hbox = QtWidgets.QHBoxLayout() + + hbox.addWidget(parent.button_home) + hbox.addWidget(parent.button_back) + hbox.addWidget(parent.button_forward) + hbox.addWidget(parent.button_pan) + hbox.addWidget(parent.button_zoom) + + hbox.addStretch(1) +# we add all our widgets in the layout + layout.addLayout(hbox) +# layout.addWidget(parent.button_zoom) +# layout.addWidget + + +# makes a horizontal layout for the buttons used to navigate through +# the time axis. + hboxtimeframes = QtWidgets.QHBoxLayout() + hboxtimeframes.addWidget(parent.button_previousframe) + hboxtimeframes.addWidget(parent.button_timeindex) + hboxtimeframes.addWidget(parent.button_nextframe) + layout.addWidget(parent.m) + + layout.addLayout(hboxtimeframes) + +# layout.addStretch(0.7) + + hboxcorrectionsbuttons = QtWidgets.QHBoxLayout() + hboxcorrectionsbuttons.addWidget(parent.button_add_region) + hboxcorrectionsbuttons.addWidget(parent.button_newcell) + + hboxcorrectionsbuttons.addWidget(parent.button_drawmouse) + hboxcorrectionsbuttons.addWidget(parent.button_eraser) + hboxcorrectionsbuttons.addWidget(parent.button_savemask) + hboxcorrectionsbuttons.addStretch(1) + layout.addLayout(hboxcorrectionsbuttons) + +# layout.addStretch(0.7) + hboxcellval = QtWidgets.QHBoxLayout() + hboxcellval.addWidget(parent.button_exval) + hboxcellval.addWidget(parent.button_changecellvalue) + hboxcellval.addStretch(1) + layout.addLayout(hboxcellval) + + + hboxlistbuttons = QtWidgets.QHBoxLayout() + hboxlistbuttons.addWidget(parent.button_fov) + hboxlistbuttons.addWidget(parent.button_channel) + hboxlistbuttons.addStretch(1) + layout.addLayout(hboxlistbuttons) + + layout.addWidget(parent.button_extractfluorescence) + + hboxcheckbox = QtWidgets.QHBoxLayout() + hboxcheckbox.addWidget(parent.button_showval) + hboxcheckbox.addWidget(parent.button_hidemask) + hboxcheckbox.addStretch(1) + + layout.addLayout(hboxcheckbox) + + + layout.addWidget(parent.button_cnn) + + layout.addLayout(hbox_threshold) + + layout.addLayout(hbox_segment) + layout.addWidget(parent.button_cellcorespondance) + + diff --git a/misc/ChangeOneCellValue.py b/misc/ChangeOneCellValue.py new file mode 100644 index 0000000000000000000000000000000000000000..951f732dbbbf64f4667e7d647de7543e82ba725d --- /dev/null +++ b/misc/ChangeOneCellValue.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- +""" +Created on Tue Nov 19 17:38:58 2019 +""" + +from PyQt5.QtWidgets import QApplication, QMainWindow, QMenu, QVBoxLayout, QSizePolicy, QMessageBox, QWidget, QPushButton, QShortcut, QComboBox, QDialog, QDialogButtonBox, QInputDialog, QLineEdit, QFormLayout +from PyQt5 import QtGui +#from PyQt5.QtGui import QIcon, QKeySequence +from PyQt5.QtCore import pyqtSignal, QObject, Qt +#import PyQt package, allows for GUI interactions + +class CustomDialog(QDialog): + + def __init__(self, *args, **kwargs): + super(CustomDialog, self).__init__(*args, **kwargs) + + self.setWindowTitle("Change Value one cell") + self.setGeometry(100,100, 500,200) + + self.entry1 = QLineEdit() + self.entry1.setValidator(QtGui.QIntValidator()) + self.entry1.setMaxLength(4) + self.entry1.setAlignment(Qt.AlignRight) + +# self.entry2 = QLineEdit() +# self.entry2.setValidator(QtGui.QIntValidator()) +# self.entry2.setMaxLength(4) +# self.entry2.setAlignment(Qt.AlignRight) + + flo = QFormLayout() + flo.addRow('Enter Cell value (integer):', self.entry1) +# flo.addRow('Enter Cell value 2 (integer):', self.entry2) + + QBtn = QDialogButtonBox.Ok | QDialogButtonBox.Cancel + + self.buttonBox = QDialogButtonBox(QBtn) + self.buttonBox.accepted.connect(self.accept) + self.buttonBox.rejected.connect(self.reject) + +# self.layout = QVBoxLayout() +# self.layout.addWidget(self.buttonBox + flo.addWidget(self.buttonBox) + self.setLayout(flo) + \ No newline at end of file diff --git a/misc/ExchangeCellValues.py b/misc/ExchangeCellValues.py new file mode 100644 index 0000000000000000000000000000000000000000..a4ef8a57a5fe2883fa76aeb1a39d7820c56ce811 --- /dev/null +++ b/misc/ExchangeCellValues.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- +""" +Created on Tue Nov 19 17:38:58 2019 +""" + +from PyQt5.QtWidgets import QApplication, QMainWindow, QMenu, QVBoxLayout, QSizePolicy, QMessageBox, QWidget, QPushButton, QShortcut, QComboBox, QDialog, QDialogButtonBox, QInputDialog, QLineEdit, QFormLayout +from PyQt5 import QtGui +#from PyQt5.QtGui import QIcon, QKeySequence +from PyQt5.QtCore import pyqtSignal, QObject, Qt +#import PyQt package, allows for GUI interactions + +class CustomDialog(QDialog): + + def __init__(self, *args, **kwargs): + super(CustomDialog, self).__init__(*args, **kwargs) + + self.setWindowTitle("Exchange Cell Values") + self.setGeometry(100,100, 500,200) + + self.entry1 = QLineEdit() + self.entry1.setValidator(QtGui.QIntValidator()) + self.entry1.setMaxLength(4) + self.entry1.setAlignment(Qt.AlignRight) + + self.entry2 = QLineEdit() + self.entry2.setValidator(QtGui.QIntValidator()) + self.entry2.setMaxLength(4) + self.entry2.setAlignment(Qt.AlignRight) + + flo = QFormLayout() + flo.addRow('Enter Cell value 1 (integer):', self.entry1) + flo.addRow('Enter Cell value 2 (integer):', self.entry2) + + QBtn = QDialogButtonBox.Ok | QDialogButtonBox.Cancel + + self.buttonBox = QDialogButtonBox(QBtn) + self.buttonBox.accepted.connect(self.accept) + self.buttonBox.rejected.connect(self.reject) + +# self.layout = QVBoxLayout() +# self.layout.addWidget(self.buttonBox + flo.addWidget(self.buttonBox) + self.setLayout(flo) + \ No newline at end of file diff --git a/packages b/packages new file mode 100644 index 0000000000000000000000000000000000000000..8e27001716ac8e7e3c2dd4c04b7973366962e3e4 --- /dev/null +++ b/packages @@ -0,0 +1,10 @@ +install numpy +install PyQt5 +install matplotlib +install nd2reader==3.2.1 +install h5py +install scikit-image +install openpyxl +install Tensorflow==1.9.0 +install Keras +install pytiff diff --git a/unet/CellCorrespondance.py b/unet/CellCorrespondance.py new file mode 100644 index 0000000000000000000000000000000000000000..8bef01a5848dfc160852e64ac08279dff6c04857 --- /dev/null +++ b/unet/CellCorrespondance.py @@ -0,0 +1,217 @@ +# -*- coding: utf-8 -*- +""" +Created on Thu Dec 12 10:03:28 2019 + +Cell Correspondance +""" + +import numpy as np +import matplotlib.pyplot as plt + + + + + + + +def cell_frame_intersection(nextmask, prevmask, flag): + """ + This function computes the intersection between the segmented frame and + the previous frame. It looks at all the segments in the new mask + (= nextmask) and for each of the segments it finds the corresponding + coordinates. These coordinates are then plugged into the previous mask + and in the corresponding region of the previous mask it takes the cell + which has the most pixel values in this region (compare the pixel counts). + And then takes the value of the value which has the biggest number + of counts in the corresponding region and sets to the segment in the new + mask. + If the cell value given by the intersection between the frame is = 0, + than it means that the new segment intersects mostly with background, + so it is assumed that a new cell is created and a new value is given + to this cell. + of course, it does not work for all the cells (some move, some grow) so + the intersection might occur with + """ + nwcells = np.unique(nextmask) + nwcells = nwcells[nwcells >0] + vals = np.empty(len(nwcells), dtype=int) + vals[:] = -1 + coord = np.empty(len(nwcells)) + coord = list(coord) + maxvalue = np.amax(prevmask) + double_smaller = [] + double_bigger = [] + + for i, cell in enumerate(nwcells): + + v, c = np.unique(prevmask[nextmask==cell], return_counts=True) + + valtemp = v[np.argmax(c)] + coord[i] = nextmask == cell + + if valtemp != 0: + if valtemp in vals: + tempind = [] + cts = [] + + for k, allval in enumerate(vals): + if allval == valtemp: + + valbool, ctmp = np.unique(coord[k], return_counts = True) + + cts.append(ctmp[valbool == True]) + tempind.append(k) + +# cts = np.array(cts) + maxpreviouscts = int(np.amax(cts)) + biggestcoord_index = tempind[np.argmax(cts)] + c = int(c[np.argmax(c)]) + + + if len(tempind) == 1: + if c > maxpreviouscts: + double_smaller.append(coord[biggestcoord_index]) + double_bigger.append(nextmask == cell) + else: + double_smaller.append(nextmask == cell) + double_bigger.append(coord[biggestcoord_index]) + else: + + if maxpreviouscts >= c: + double_smaller.append(nextmask == cell) + + else: + + for k in range(0, len(double_bigger)): + if not(False in (double_bigger[k] == coord[biggestcoord_index])): + double_smaller.append(double_bigger[k]) + double_bigger[k] = coord[biggestcoord_index] + + + + vals[i] = valtemp + + + + + else: + if flag: + maxvalue = maxvalue + 1 + vals[i] = maxvalue + else: + vals[i] = 0 + + out = nextmask.copy() + for k, v in zip(nwcells, vals): + out[k==nextmask] = v + return out, double_bigger, double_smaller + + + + + + + + +def CellCorrespondance(nextmask, prevmask, returnbool): + + nm = nextmask.copy() + pm = prevmask.copy() + + NotifyRegionMask = nextmask.copy() + NotifyRegionMask[NotifyRegionMask > 0] = 0 + + CorrespondanceMask, bigcluster, smallcluster = cell_frame_intersection(nm, pm, returnbool) + + cm = CorrespondanceMask.copy() + + oldcells = np.unique(prevmask) + oldcells = oldcells[oldcells > 0] + + for cellval in oldcells: + if not(cellval in CorrespondanceMask): + + NotifyRegionMask[prevmask == cellval] = 1 + + + + for coordinates in smallcluster: + val = np.unique(prevmask[coordinates]) + for cellval in oldcells: + if not(cellval in CorrespondanceMask): + if cellval in val: + + CorrespondanceMask[coordinates] = cellval + NotifyRegionMask[coordinates] = 1 + + + NotifyRegionMask[coordinates] = 1 + + + + + + + return CorrespondanceMask, NotifyRegionMask + + + + +def CellCorrespondancePlusTheReturn(nextmask, prevmask): + + firstcorresp, notifymask = CellCorrespondance(nextmask, prevmask, True) + + firstmask = firstcorresp.copy() + pm = prevmask.copy() + fm = firstcorresp.copy() +# pm2 = prevmask.copy() + + cmreturn, bigclustr, smallclustr = cell_frame_intersection(pm, firstmask, False) + +# cmreturn, notifymask = CellCorrespondance(pm, firstmask, False) + oldcells = np.unique(fm) + oldcells = oldcells[oldcells > 0] + + for cellval in oldcells: + if not(cellval in cmreturn): + + notifymask[fm == cellval] = 2 + + + + for coordinates in smallclustr: + + val = np.unique(fm[coordinates]) + for cellval in oldcells: + if not(cellval in cmreturn): + if cellval in val: + + notifymask[coordinates] = 2 + + notifymask[coordinates] = 2 + + + + + + return fm, notifymask + + + + + + + + + + + + + + + + + + + + diff --git a/unet/LaunchBatchPrediction.py b/unet/LaunchBatchPrediction.py new file mode 100644 index 0000000000000000000000000000000000000000..e18a64a35f5fe2b493326ba7c6a231456714ff09 --- /dev/null +++ b/unet/LaunchBatchPrediction.py @@ -0,0 +1,72 @@ +# -*- coding: utf-8 -*- +""" +Created on Tue Nov 19 17:38:58 2019 +""" + +from PyQt5.QtWidgets import QApplication, QMainWindow, QMenu, QVBoxLayout, QSizePolicy, QMessageBox, QWidget, QPushButton, QShortcut, QComboBox, QDialog, QDialogButtonBox, QInputDialog, QLineEdit, QFormLayout, QLabel, QListWidget, QAbstractItemView +from PyQt5 import QtGui +#from PyQt5.QtGui import QIcon, QKeySequence +from PyQt5.QtCore import pyqtSignal, QObject, Qt +#import PyQt package, allows for GUI interactions +import pdb +class CustomDialog(QDialog): + + def __init__(self, *args, **kwargs): + + super(CustomDialog, self).__init__(*args, **kwargs) + + app, = args + maxtimeindex = app.reader.sizet + maxfovindex = app.reader.Npos + + self.setWindowTitle("Launch NN") + self.setGeometry(100,100, 500,200) + + + + self.entry1 = QLineEdit() + self.entry1.setValidator(QtGui.QIntValidator(0,int(maxtimeindex-1))) + + + self.entry2 = QLineEdit() + self.entry2.setValidator(QtGui.QIntValidator(0,int(maxtimeindex-1))) + + + self.listfov = QListWidget() + self.listfov.setSelectionMode(QAbstractItemView.MultiSelection) + + for f in range(0, app.reader.Npos): + self.listfov.addItem('Field of View {}'.format(f+1)) + + + self.labeltime = QLabel("Enter range ({}-{}) for time axis".format(0, app.reader.sizet-1)) + + self.entry_threshold = QLineEdit() + self.entry_threshold.setValidator(QtGui.QDoubleValidator()) + + self.entry_segmentation = QLineEdit() + self.entry_segmentation.setValidator(QtGui.QIntValidator()) + + + flo = QFormLayout() + flo.addWidget(self.labeltime) + flo.addRow('Lower Boundary for time axis', self.entry1) + flo.addRow('Upper Boundary for time axis', self.entry2) + + + flo.addRow('Select Fields of fiew from the list', self.listfov) + + flo.addRow('Enter a threshold value', self.entry_threshold) + flo.addRow('Enter a segmentation value', self.entry_segmentation) + + QBtn = QDialogButtonBox.Ok | QDialogButtonBox.Cancel + + self.buttonBox = QDialogButtonBox(QBtn) + self.buttonBox.accepted.connect(self.accept) + self.buttonBox.rejected.connect(self.reject) + + flo.addWidget(self.buttonBox) + self.setLayout(flo) + + + diff --git a/unet/data.py b/unet/data.py new file mode 100644 index 0000000000000000000000000000000000000000..2f37a4a0caa59bbc0b953616f05d9deea7d98951 --- /dev/null +++ b/unet/data.py @@ -0,0 +1,127 @@ +""" +Source of the code: https://github.com/zhixuhao/unet +""" +from __future__ import print_function +from tensorflow.keras.preprocessing.image import ImageDataGenerator +import numpy as np +import os +import glob +import skimage.io as io +import skimage.transform as trans + +Sky = [128,128,128] +Building = [128,0,0] +Pole = [192,192,128] +Road = [128,64,128] +Pavement = [60,40,222] +Tree = [128,128,0] +SignSymbol = [192,128,128] +Fence = [64,64,128] +Car = [64,0,128] +Pedestrian = [64,64,0] +Bicyclist = [0,128,192] +Unlabelled = [0,0,0] + +COLOR_DICT = np.array([Sky, Building, Pole, Road, Pavement, + Tree, SignSymbol, Fence, Car, Pedestrian, Bicyclist, Unlabelled]) + + +def adjustData(img,mask,flag_multi_class,num_class): + if(flag_multi_class): + img = img / 255 + mask = mask[:,:,:,0] if(len(mask.shape) == 4) else mask[:,:,0] + new_mask = np.zeros(mask.shape + (num_class,)) + for i in range(num_class): + #for one pixel in the image, find the class in mask and convert it into one-hot vector + #index = np.where(mask == i) + #index_mask = (index[0],index[1],index[2],np.zeros(len(index[0]),dtype = np.int64) + i) if (len(mask.shape) == 4) else (index[0],index[1],np.zeros(len(index[0]),dtype = np.int64) + i) + #new_mask[index_mask] = 1 + new_mask[mask == i,i] = 1 + new_mask = np.reshape(new_mask,(new_mask.shape[0],new_mask.shape[1]*new_mask.shape[2],new_mask.shape[3])) if flag_multi_class else np.reshape(new_mask,(new_mask.shape[0]*new_mask.shape[1],new_mask.shape[2])) + mask = new_mask + elif(np.max(img) > 1): + img = img / 255 + mask = mask /255 + mask[mask > 0.5] = 1 + mask[mask <= 0.5] = 0 + return (img,mask) + + + +def trainGenerator(batch_size,train_path,image_folder,mask_folder,aug_dict,image_color_mode = "grayscale", + mask_color_mode = "grayscale",image_save_prefix = "image",mask_save_prefix = "mask", + flag_multi_class = False,num_class = 2,save_to_dir = None,target_size = (256,256),seed = 1): + ''' + can generate image and mask at the same time + use the same seed for image_datagen and mask_datagen to ensure the transformation for image and mask is the same + if you want to visualize the results of generator, set save_to_dir = "your path" + ''' + image_datagen = ImageDataGenerator(**aug_dict) + mask_datagen = ImageDataGenerator(**aug_dict) + image_generator = image_datagen.flow_from_directory( + train_path, + classes = [image_folder], + class_mode = None, + color_mode = image_color_mode, + target_size = target_size, + batch_size = batch_size, + save_to_dir = save_to_dir, + save_prefix = image_save_prefix, + seed = seed) + mask_generator = mask_datagen.flow_from_directory( + train_path, + classes = [mask_folder], + class_mode = None, + color_mode = mask_color_mode, + target_size = target_size, + batch_size = batch_size, + save_to_dir = save_to_dir, + save_prefix = mask_save_prefix, + seed = seed) + train_generator = zip(image_generator, mask_generator) + for (img,mask) in train_generator: + img,mask = adjustData(img,mask,flag_multi_class,num_class) + yield (img,mask) + + + +def testGenerator(test_path,num_image = 30,target_size = (256,256),flag_multi_class = False,as_gray = True): + for i in range(num_image): + img = io.imread(os.path.join(test_path,"%d.png"%i),as_gray = as_gray) + img = img / 255 + img = trans.resize(img,target_size) + img = np.reshape(img,img.shape+(1,)) if (not flag_multi_class) else img + img = np.reshape(img,(1,)+img.shape) + yield img + + +def geneTrainNpy(image_path,mask_path,flag_multi_class = False,num_class = 2,image_prefix = "image",mask_prefix = "mask",image_as_gray = True,mask_as_gray = True): + image_name_arr = glob.glob(os.path.join(image_path,"%s*.png"%image_prefix)) + image_arr = [] + mask_arr = [] + for index,item in enumerate(image_name_arr): + img = io.imread(item,as_gray = image_as_gray) + img = np.reshape(img,img.shape + (1,)) if image_as_gray else img + mask = io.imread(item.replace(image_path,mask_path).replace(image_prefix,mask_prefix),as_gray = mask_as_gray) + mask = np.reshape(mask,mask.shape + (1,)) if mask_as_gray else mask + img,mask = adjustData(img,mask,flag_multi_class,num_class) + image_arr.append(img) + mask_arr.append(mask) + image_arr = np.array(image_arr) + mask_arr = np.array(mask_arr) + return image_arr,mask_arr + + +def labelVisualize(num_class,color_dict,img): + img = img[:,:,0] if len(img.shape) == 3 else img + img_out = np.zeros(img.shape + (3,)) + for i in range(num_class): + img_out[img == i,:] = color_dict[i] + return img_out / 255 + + + +def saveResult(save_path,npyfile,flag_multi_class = False,num_class = 2): + for i,item in enumerate(npyfile): + img = labelVisualize(num_class,COLOR_DICT,item) if flag_multi_class else item[:,:,0] + io.imsave(os.path.join(save_path,"%d_predict.png"%i),img) diff --git a/unet/data_processing.py b/unet/data_processing.py new file mode 100644 index 0000000000000000000000000000000000000000..c0836f118bd4523fb293601a5dba629dbc9e70c3 --- /dev/null +++ b/unet/data_processing.py @@ -0,0 +1,352 @@ +import numpy as np + +import skimage +from skimage import io +from skimage.util import img_as_ubyte +from skimage import morphology + + +############# +# # +# READING # +# # +############# + +def read_im_tiff(path, num_frames=None): + """ + Read a multiple page tiff images file, adapt the contrast and stock each + frame in a np.array + Param: + path: path to the tiff movie + num_frames: (integer) number of frames to read + Return: + images: + """ + + ims = io.imread(path) + images = [] + + if(num_frames == None): + num_frames = ims.shape[0] + + for i in range(num_frames): + im = skimage.exposure.equalize_adapthist(ims[i]) + images.append(np.array(im)) + + return images + + + +def read_lab_tiff(path, num_frames=None): + """ + Read a multiple page tiff label file and stock each frame in a np.array + Param: + path: path to the tiff movie + num_frames: (integer) number of frames to read + Return: + images: (np.array) array containing the desired frames + """ + + ims = io.imread(path) + images = [] + + if(num_frames == None): + + num_frames = ims.shape[0] + + for i in range(num_frames): + + images.append(np.array(ims[i])) + + return images + + + +################################ +# # +# GENERAL PROCESSING METHODS # +# # +################################ + +def pad(im, size): + """ + Carry a mirror padding on an image to end up with a squared image dividable in multiple tiles of shape size x size pixels + Param: + im: (np.array) input image + size: (integer) size of the tiles + Return: + out: (np.array) output image reshaped to the good dimension + """ + + # add mirror part along x-axis + nx = im.shape[1] + ny = im.shape[0] + + if(nx % size != 0.0): + restx = int(size - nx % size) + outx = np.pad(im, ((0,0), (0,restx)), 'symmetric') + else: + outx = im + + if(ny % size != 0.0): + resty = int(size - ny % size) + out = np.pad(outx, ((0,resty), (0,0)), 'symmetric') + else: + out = outx + + return out + +def threshold(im,th = None): + """ + Binarize an image with a threshold given by the user, or if the threshold is None, calculate the better threshold with isodata + Param: + im: a numpy array image (numpy array) + th: the value of the threshold (feature to select threshold was asked by the lab) + Return: + bi: threshold given by the user (numpy array) + """ + if th == None: + th = skimage.filters.threshold_isodata(im) + bi = im + bi[bi > th] = 255 + bi[bi <= th] = 0 + return bi + + +def edge_detection(im): + """ + Detect the edges on a label image + Param: + im: (np.array) input image + Return: + contour: (np.array) output image containing the edges of the input image + """ + + contour = np.zeros((im.shape[0], im.shape[1])) + vals = np.unique(im) + + for i in range(0,vals.shape[0]): + a = np.zeros((im.shape[0], im.shape[1])) + a[im == vals[i]] = 255 + + dilated = skimage.morphology.dilation(a, selem=None, out=None) + eroded = skimage.morphology.erosion(a, selem=None, out=None) + edge = dilated - eroded + + contour = contour + edge + + contour[contour >= 255] = 255 + + return contour + + +def split(im, size): + """ + split an squared image with good dimensions (output of pad()) into tiles of shape size x size pixels + Param: + im: (np.array) input image + size: (integer) size of the tiles + Return: + ims: (list of np.array) output tiles + """ + + nx = im.shape[1] + ny = im.shape[0] + + k_max = int(nx / size) # number of 256 slices along x axis + l_max = int(ny / size) # number of 256 slices along y axis + + ims = [] + + for l in range(0, l_max): + for k in range(0, k_max): + frame = np.zeros((size,size)) + + lo_x = size * k + hi_x = size * k + size + + lo_y = size * l + hi_y = size * l + size + + frame = im[lo_y:hi_y, lo_x:hi_x] + + # padding of the image to avoid border artefacts due to the convolutions + # out = np.pad(frame, ((10,10), (10,10)), 'symmetric') + + ims.append(frame) + return ims + + + +def split_data(im, lab, ratio, seed): + """split the dataset based on the split ratio.""" + # set seed + np.random.seed(seed) + + index= np.arange(len(im)) + np.random.shuffle(index) + num = int(ratio*len(index)) + im = im[index] + lab = lab[index] + + im_tr = im[0:num,:,:] + lab_tr = lab[0:num,:,:] + + im_te = im[num:,:,:] + lab_te = lab[num:,:,:] + + return im_tr, lab_tr, im_te, lab_te + + + +#################################### +# # +# TRAIN AND TEST SETS GENERATORS # +# # +#################################### + +def generate_test_set(im, out_im_path): + """ + Generate the testing set from raw data + Param: + im: (np.array) input image + out_im_path: path to save the testing image + + Return: + img_num: (integer) number of image produced + resized_shape: (tupple) shape of the padded image + original_shape: (tupple) shape of the original image + """ + im = skimage.exposure.equalize_adapthist(im) + + original_shape = im.shape + + #resizing the input images + padded = pad(im, 236) + resized_shape = padded.shape + + # splitting + splited = split(padded, 236) + + # padding the tiles to get 256x256 tiles + padded_split = [] + for tile in splited: + padded_split.append(np.pad(tile, ((10,10), (10,10)), 'symmetric')) + + img_num = len(padded_split) + + #saving the ouput images + for i in range(len(padded_split)): + name = str(i) + io.imsave( out_im_path + name + ".png", img_as_ubyte(padded_split[i]) ) + + return img_num, resized_shape, original_shape + + + +def generate_tr_val_set(im_col, lab_col, tr_im_path, tr_lab_path, val_im_path, val_lab_path): + """ + Randomly generate training, validation and testing set from a given collection of images and labels with a 50/25/25 ratio + for detection of whole cells + Params: + im_col: (list of string) list of images (tiff movies) to include in the sets + lab_col: (list of string) list of labels (tiff movies) to include in the sets + tr_im_path: path to save training images + tr_lab_path: path to save training labels + val_im_path: path to save validation images + val_lab_path: path to save validation labels + Returns: + tr_len: (integer) number of samples in the training set + val_len: (integer) number of samples in the validation set + """ + # reading raw data + ims = [] + labs = [] + + for im in im_col: + print('im', im) + ims = ims + read_im_tiff(im) + + for lab in lab_col: + print('label',lab) + labs = labs + read_lab_tiff(lab) + + ims_out = [] + labs_out = [] + + + for i in range( len(ims) ): + # resizing images + im_out = pad(ims[i], 236) + + # resizing and binarizing whole cell label + + threshold(labs[i],0) + lab_out = pad(labs[i], 236) + + # splitting the images + split_im = split(im_out, 236) + split_lab = split(lab_out, 236) + + # discarding images showing background only + # padding the tiles + split_im_out = [] + split_lab_out = [] + for j in range( len(split_lab) ): + if( np.sum(split_lab[j]) > 0.1 * 255 * 256 * 256 ): + split_im_out.append(np.pad(split_im[j], ((10,10), (10,10)), 'symmetric')) + split_lab_out.append(np.pad(split_lab[j], ((10,10), (10,10)), 'symmetric')) + + ims_out = ims_out + split_im_out + labs_out = labs_out + split_lab_out + + # splitting the list into multiple sets + im_tr, lab_tr, im_val, lab_val = split_data(np.array(ims_out), + np.array(labs_out), + 0.75, + 1) + + tr_len = im_tr.shape[0] + val_len = im_val.shape[0] + + #saving the images + for i in range(tr_len): + io.imsave(tr_im_path + str(i) + ".png", img_as_ubyte(im_tr[i,:,:])) + io.imsave(tr_lab_path + str(i) + ".png", (lab_tr[i,:,:])) + + for j in range(val_len): + io.imsave(val_im_path + str(j) + ".png", img_as_ubyte(im_val[j,:,:])) + io.imsave(val_lab_path + str(j) + ".png", (lab_val[j,:,:])) + + return tr_len, val_len + +def reconstruct_result(tile_size, result, resized_shape, origin_shape): + """ + Assemble a set of tiles to reconstruct the original, unsplitted image + Param: + tile_size: (integer) size of the tiles for the reconstruction + result: (np.array) result images of the network prediction + out_result_path: path to save the results + resized_shape: (tuple) size of the image padded for the splitting + origin_shape: (tuple) size or the raw images + Return: + out: (np.array) array containing the reconstructed images + """ + nx, ny = int(resized_shape[1] / tile_size), int(resized_shape[0] / tile_size) + out = np.empty(resized_shape) + + i = 0 + for l in range(ny): + for k in range(nx): + + lo_x = tile_size * k + hi_x = tile_size * k + tile_size + + lo_y = tile_size * l + hi_y = tile_size * l + tile_size + + out[lo_y:hi_y, lo_x:hi_x] = result[i,:,:] + + i = i+1 + + return out[ 0:origin_shape[0], 0:origin_shape[1] ] diff --git a/unet/model.py b/unet/model.py new file mode 100644 index 0000000000000000000000000000000000000000..717700fcb58d12b7351336f50acd542cf4c13dda --- /dev/null +++ b/unet/model.py @@ -0,0 +1,88 @@ +""" +Source of the code: https://github.com/zhixuhao/unet +""" +import numpy as np +import os +import skimage.io as io +import skimage.transform as trans +import numpy as np +from tensorflow.keras.models import * +from tensorflow.keras.layers import * +from tensorflow.keras.optimizers import * +from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler +from tensorflow.keras import backend as keras + + +from tensorflow import ConfigProto +from tensorflow import InteractiveSession +#from tensorflow.compat.v1 import Session +#from tensorflow.compat.v1 import get_default_graph +#from tensorflow.compat.v1 import S +#import tensorflow +config = ConfigProto() +config.gpu_options.allow_growth = True +session = InteractiveSession(config=config) + +#session_conf = ConfigProto(intra_op_parallelism_threads=8, inter_op_parallelism_threads=8, device_count = {'GPU':0}) +#tensorflow.random.set_seed(1) +#sess = Session(graph=get_default_graph(), config=session_conf) +#tensorflow.compat.v1.keras.backend.set_session(sess) +#gpus = tf.config.experimental.list_physical_devices('GPU') +#tf.config.experimental.set_memory_growth(gpus[0], True) + + +#os.environ['CUDA_VISIBLE_DEVICES'] = '-1' + + +def unet(pretrained_weights = None,input_size = (256,256,1)): + inputs = Input(input_size) + conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(inputs) + conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv1) + pool1 = MaxPooling2D(pool_size=(2, 2))(conv1) + conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool1) + conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv2) + pool2 = MaxPooling2D(pool_size=(2, 2))(conv2) + conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool2) + conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv3) + pool3 = MaxPooling2D(pool_size=(2, 2))(conv3) + conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool3) + conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv4) + drop4 = Dropout(0.5)(conv4) + pool4 = MaxPooling2D(pool_size=(2, 2))(drop4) + + conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool4) + conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv5) + drop5 = Dropout(0.5)(conv5) + + up6 = Conv2D(512, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop5)) + merge6 = concatenate([drop4,up6], axis = 3) + conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge6) + conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv6) + + up7 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv6)) + merge7 = concatenate([conv3,up7], axis = 3) + conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge7) + conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv7) + + up8 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv7)) + merge8 = concatenate([conv2,up8], axis = 3) + conv8 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge8) + conv8 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv8) + + up9 = Conv2D(64, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv8)) + merge9 = concatenate([conv1,up9], axis = 3) + conv9 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge9) + conv9 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv9) + conv9 = Conv2D(2, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv9) + conv10 = Conv2D(1, 1, activation = 'sigmoid')(conv9) + + model = Model(inputs = inputs, outputs = conv10) + + model.compile(optimizer = Adam(lr = 1e-4), loss = 'binary_crossentropy', metrics = ['accuracy']) + + #model.summary() + + if(pretrained_weights): + model.load_weights(pretrained_weights) + + return model diff --git a/unet/neural_network.py b/unet/neural_network.py new file mode 100644 index 0000000000000000000000000000000000000000..5a73bdf28b00371a3c3f481e4bd2a44d525d1211 --- /dev/null +++ b/unet/neural_network.py @@ -0,0 +1,125 @@ + +# -*- coding: utf-8 -*- +""" +Created on Sat Dec 21 18:54:10 2019 + +""" +from model import * +from data import * +#from quality_measures import * +from segment import * +from data_processing import * + +import numpy as np +import skimage +from skimage import io + +def create_directory_if_not_exists(path): + """ + Create in the file system a new directory if it doesn't exist yet. + Param: + path: the path of the new directory + """ + if not os.path.exists(path): + os.makedirs(path) + + +def threshold(im,th = None): + """ + Binarize an image with a threshold given by the user, or if the threshold is None, calculate the better threshold with isodata + Param: + im: a numpy array image (numpy array) + th: the value of the threshold (feature to select threshold was asked by the lab) + Return: + bi: threshold given by the user (numpy array) + """ + im2 = im.copy() + if th == None: + th = skimage.filters.threshold_isodata(im2) + bi = im2 + bi[bi > th] = 255 + bi[bi <= th] = 0 + return bi + + +def prediction(im): + """ + Calculate the prediction of the label corresponding to image im + Param: + im: a numpy array image (numpy array), with max size 2048x2048 + Return: + res: the predicted distribution of probability of the labels (numpy array) + """ + path_test = './tmp/test/image/' + create_directory_if_not_exists(path_test) + + io.imsave(path_test+'0.png',im) + # TESTING SET +# img_num, resized_shape, original_shape = generate_test_set(im,path_test) + + # WHOLE CELL PREDICTION + testGene = testGenerator(path_test, + 1, + target_size = (2048,2048) ) + + model = unet( pretrained_weights = None, + input_size = (2048,2048,1) ) + + model.load_weights('./unet_weights_batchsize_25_Nepochs_100_full.hdf5') + + results = model.predict_generator(testGene, + 1, + verbose=1) + +# res = reconstruct_result(236, results[:,10:246,10:246,0], resized_shape,original_shape) +# if the size of the image is not 2048x2048, the difference between 2048x2048 +# and the original size is cut off here. (in the prediction it is +# artificially augmented from the original size to 2048x2048) + + index_x = 2048-len(im[:,0]) + index_y = 2048-len(im[0,:]) + #this bolean values are true if index_x(y)/2 else they are false. + #they are initialized to false. + flagx = False + flagy = False + + #test if x dimension of im is not already the max size + if index_x != 0: + #if index_x not the 0 (im has not max size in x axis) then the + #the difference is divided by two, index_x is the difference between + #size of im in x axis and max size 2048 + ind_x = index_x/2 + if index_x%2 == 0: + ind_x = int(ind_x) + flagx = True + else: + #if already max size, it is set to 0 + ind_x = 0 + flagx = True + + #test if y dimension of im is not already the max size + if index_y != 0: + #if index_y not the 0 (im has not max size in y axis) then the + #the difference is divided by two, index_y is the difference between + #size of im in y axis and max size 2048 + ind_y = index_y/2 + if index_y%2 == 0: + ind_y = int(ind_y) + flagy = True + else: + ind_y = 0 + flagy = True + + + if flagx and flagy: + res = results[0,ind_x:2048-ind_x,ind_y:2048-ind_y,0] + return res + elif not(flagx) and flagy: + res = results[0,int(ind_x):2048-(int(ind_x)+1),ind_y:2048-ind_y,0] + return res + elif flagx and not(flagy): + res = results[0,ind_x:2048-ind_x, int(ind_y):2048-(int(ind_y)+1),0] + return res + else: + res = results[0, int(ind_x):2048-(int(ind_x)+1),int(ind_y):2048-(int(ind_y)+1),0] + return res diff --git a/unet/quality_measures.py b/unet/quality_measures.py new file mode 100644 index 0000000000000000000000000000000000000000..b5b7d7cce2fefbda2bd87df4b008f5dac37b8285 --- /dev/null +++ b/unet/quality_measures.py @@ -0,0 +1,166 @@ +""" +Source of the code: https://github.com/mattminder/yeastSegHelpers +""" +import numpy as np + +def cell_correspondance_frame(truth, pred): + """ + Finds for one frame the correspondance between the true and the predicted cells. + Returns dictionary from predicted cell to true cell. + """ + keys = np.unique(pred) + vals = np.empty(len(keys), dtype=int) + for i, cell in enumerate(keys): + v, c = np.unique(truth[pred==cell], return_counts=True) + vals[i] = v[np.argmax(c)] + return dict(zip(keys, vals)) + +def split_dict(cc): + """ + Returns dictionary of cells in predicted image, to whether they were split unnecessarily + (includes cells that were assigned to background) + """ + truth = list(cc.values()) + pred = list(cc.keys()) + + v, c = np.unique(list(truth), return_counts=True) + split_truth = v[c>1] + return dict(zip(pred, np.isin(truth, split_truth))) + +def fused_dict(cc, truth, pred): + """ + Returns dictionary of cells in predicted image to whether they were fused unnecessarily. + """ + rev_cc = cell_correspondance_frame(pred, truth) + rev_split_dict = split_dict(rev_cc) + out = dict() + for p in cc: + out[p] = rev_split_dict[cc[p]] + return out + +def nb_splits(cc): + """ + Counts the number of unnecessary splits in the predicted image (ignores background). + """ + v, c = np.unique(list(cc.values()), return_counts=True) + return (c[v>0]-1).sum() + +def nb_fusions(cc, truth): + """ + Counts number of cells that were fused. + """ + true_cells = set(np.unique(truth)) + pred_cells = set(list(cc.values())) + return len(true_cells - pred_cells) + +def nb_artefacts(cc): + """ + Number of predicted cells that are really background. + """ + valarr = np.array(list(cc.values())) + keyarr = np.array(list(cc.keys())) + return (valarr[keyarr>0]==0).sum() + +def nb_false_negatives(truth, pred): + """ + Number of cells of mask that were detected as background in the prediction. + """ + rev_cc = cell_correspondance_frame(pred, truth) + return nb_artefacts(rev_cc) + +def over_undershoot(truth, pred, cc, look_at): + """ + Calculates average number of pixels that were overshot by cell. Uses look_at, which + is a dict from pred cells to whether they should be considered (allows to exclude wrongly + fused or split cells) + """ + overshoot = 0 + undershoot = 0 + cellcount = 0 + + for p in cc: + if not look_at[p]: + continue + t = cc[p] + if t==0 or p==0: # Disregard if truth or prediction is background + continue + cellcount += 1 + overshoot += (pred[truth!=t]==p).sum() + undershoot += (truth[pred!=p]==t).sum() + return overshoot/cellcount, undershoot/cellcount + +def average_pred_area(pred, cc, look_at): + """ + Calculates average predicted area of all considered cells + """ + area = 0 + cellcount = 0 + for p in cc: + if not look_at[p]: + continue + if p==0: + continue + cellcount += 1 + area += (pred==p).sum() + return area/cellcount + +def average_true_area(truth, cc, look_at): + """ + Calculates average true area of all considered cells + """ + area = 0 + cellcount = 0 + for p in cc: + if not look_at[p]: + continue + t = cc[p] + if t==0: + continue + cellcount += 1 + area += (truth==t).sum() + return area/cellcount + +def n_considered_cells(cc, look_at): + """ + Calculates the number of considered cells + """ + count = 0 + for p in cc: + if not look_at[p]: + continue + if p==0 or cc[p]==0: + continue + count += 1 + return count + +def quality_measures(truth, pred): + """ + Tests quality of prediction with four statistics: + Number Fusions: How many times are multiple true cells predicted as a single cell? + Number Splits: How many times is a single true cell split into multiple predicted cell? + Av. Overshoot: How many pixels are wrongly predicted to belong to the cell on average + Av. Undershoot: How many pixels are wrongly predicted to not belong to the cell on average + Av. true area: Average area of cells of truth that are neither split nor fused + Av. pred area: Average area of predicted cells that are neither split nor fused + Nb considered cells:Number of cells that are neither split nor fused + """ + cc = cell_correspondance_frame(truth, pred) + res = dict() + + # Get indices of cells that are not useable for counting under- and overshooting + is_split = split_dict(cc) + is_fused = fused_dict(cc, truth, pred) + look_at = dict() + for key in is_split: + look_at[key] = not (is_split[key] or is_fused[key]) + + # Result Calculation + res["Number Fusions"] = nb_fusions(cc, truth) + res["Number Splits"] = nb_splits(cc) + res["Nb False Positives"] = nb_artefacts(cc) + res["Nb False Negatives"] = nb_false_negatives(truth, pred) + res["Average Overshoot"], res["Average Undershoot"] = over_undershoot(truth, pred, cc, look_at) + res["Av. True Area"] = average_true_area(truth, cc, look_at) + res["Av. Pred Area"] = average_pred_area(pred, cc, look_at) + res["Nb Considered Cells"] = n_considered_cells(cc, look_at) + return res diff --git a/unet/segment.py b/unet/segment.py new file mode 100755 index 0000000000000000000000000000000000000000..8a4c805b522b43f1a0eaba283deaac5550d1cd32 --- /dev/null +++ b/unet/segment.py @@ -0,0 +1,96 @@ +""" +Source of the code: https://github.com/mattminder/yeastSegHelpers +""" +from scipy import ndimage as ndi +from skimage.feature import peak_local_max +from skimage.morphology import watershed + +# from PIL import Image +import numpy as np +from skimage import data, util, filters, color +import cv2 + +def segment(th, pred, min_distance=10, topology=None): #SJR: added pred to evaluate new borders + """ + Performs watershed segmentation on thresholded image. Seeds have to + have minimal distance of min_distance. topology defines the watershed + topology to be used, default is the negative distance transform. Can + either be an array with the same size af th, or a function that will + be applied to the distance transform. + """ + dtr = ndi.morphology.distance_transform_edt(th) + if topology is None: + topology = -dtr + elif callable(topology): + topology = topology(dtr) + + m = peak_local_max(-topology, min_distance, indices=False) + m_lab = ndi.label(m)[0] + wsh = watershed(topology, m_lab, mask=th) + + wshshape=wsh.shape # size of the watershed images, could probably get the sizes from the original input images but too lazy to figure out how + oriobjs=np.zeros((wsh.max()+1,wshshape[0],wshshape[1])) # the masks for the original cells are saved each separately here + dilobjs=np.zeros((wsh.max()+1,wshshape[0],wshshape[1])) # the masks for dilated cells are saved each separately here + objcoords=np.zeros((wsh.max()+1,4)) # coordinates of the bounding boxes for each dilated object saved here + wshclean=np.zeros((wshshape[0],wshshape[1])) + + kernel = np.ones((3,3), np.uint8) # need kernel to dilate objects + for obj1 in range(0,wsh.max()): # objects numbered starting with 0!! Careful!! + oriobjs[obj1,:,:] = np.uint8(np.multiply(wsh==(obj1+1),1)) # create a mask for each object (number: obj1+1) + dilobjs[obj1,:,:] = cv2.dilate(oriobjs[obj1,:,:], kernel, iterations=1) # create a mask for each object (number: obj1+1) after dilation + objallpixels=np.where(dilobjs[obj1,:,:] != 0) + objcoords[obj1,0]=np.min(objallpixels[0]) + objcoords[obj1,1]=np.max(objallpixels[0]) + objcoords[obj1,2]=np.min(objallpixels[1]) + objcoords[obj1,3]=np.max(objallpixels[1]) + + objcounter = 0 # will build up a new watershed mask, have to run a counter because some objects lost + for obj1 in range(0,wsh.max()): #careful, object numbers -1 ! + print("Processing cell ",obj1+1," of ",wsh.max()," for oversegmentation.") + maskobj1 = dilobjs[obj1,:,:] + + if np.sum(maskobj1) > 0: #maskobj1 can be empty because in the loop, maskobj2 can be deleted if it is joined with a (previous) maskobj1 + objcounter = objcounter + 1 + maskoriobj1 = oriobjs[obj1,:,:] + + for obj2 in range(obj1+1,wsh.max()): + maskobj2 = dilobjs[obj2,:,:] + + if (np.sum(maskobj2) > 0 and #maskobj1 and 2 can be empty because joined with maskobj2's and then maskobj2's deleted (set to zero) + (((objcoords[obj1,0] - 2 < objcoords[obj2,0] and objcoords[obj1,1] + 2 > objcoords[obj2,0]) or # do the bounding boxes overlap? plus/minus 2 pixels to allow for bad bounding box measurement + (objcoords[obj2,0] - 2 < objcoords[obj1,0] and objcoords[obj2,1] + 2 > objcoords[obj1,0])) and + ((objcoords[obj1,2] - 2 < objcoords[obj2,2] and objcoords[obj1,3] + 2 > objcoords[obj2,2]) or + (objcoords[obj2,2] - 2 < objcoords[obj1,2] and objcoords[obj2,3] + 2 > objcoords[obj1,2])))): + border = maskobj1 * maskobj2 #intersection of two masks constitutes a border + # borderarea = np.sum(border) + # borderpred = border * pred + # borderheight = np.sum(borderpred) + + borderprednonzero = pred[np.nonzero(border)] # all the prediction values inside the border area + sortborderprednonzero = sorted(borderprednonzero) # sort the values + borderprednonzeroarea = len(borderprednonzero) # how many values are there? + quartborderarea = round(borderprednonzeroarea/4) # take one fourth of the values. there is some subtlety about how round() rounds but doesn't matter + topborderpred = sortborderprednonzero[quartborderarea:] # take top 3/4 of the predictions + topborderheight = np.sum(topborderpred) # sum over top 3/4 of the predictions + topborderarea = len(topborderpred) # area of 3/4 of predictions. In principle equal to 3/4 of borderprednonzeroarea but because of strange rounding, will just measure again + + if topborderarea > 8: # SJR: Not only must borderarea be greater than 0 but also have a little bit of border to go on. + #print(obj1+1, obj2+1, topborderheight/topborderarea) + if topborderheight/topborderarea > 0.99 : # SJR: We are really deep inside a cell, where the prediction is =1. Won't use: borderheight/borderarea > 0.95. Redundant. + #print("--") + #print(objcounter) + #wsh=np.where(wsh==obj2+1, obj1+1, wsh) + maskoriobj1 = np.uint8(np.multiply((maskoriobj1 > 0) | (oriobjs[obj2,:,:] > 0),1)) #have to do boolean then integer just to do an 'or' + dilobjs[obj1,:,:] = np.uint8(np.multiply((maskobj1 > 0) | (maskobj2 > 0),1)) #have to do boolean then integer just to do an 'or' + dilobjs[obj2,:,:] = np.zeros((wshshape[0],wshshape[1])) + objcoords[obj1,0] = min(objcoords[obj1,0],objcoords[obj2,0]) + objcoords[obj1,1] = max(objcoords[obj1,1],objcoords[obj2,1]) + objcoords[obj1,2] = min(objcoords[obj1,2],objcoords[obj2,2]) + objcoords[obj1,3] = max(objcoords[obj1,3],objcoords[obj2,3]) + + wshclean = wshclean + maskoriobj1*objcounter + #else: + # display(obj1+1,' no longer there.') + + return wshclean +