Skip to content
Snippets Groups Projects
Commit c5097d11 authored by mattminder's avatar mattminder
Browse files

Warn when no weights

parent 54cc88c7
No related branches found
No related tags found
No related merge requests found
......@@ -708,21 +708,30 @@ class App(QMainWindow):
def PredThreshSeg(self, timeindex, fovindex, thr_val, seg_val,
is_pc):
"""
This function is called in the LaunchBatchPrediction function.
This function calls the neural network function in the
InteractionDisk.py file and then thresholds the result
of the prediction, saves this thresholded prediction.
Then it segments the thresholded prediction and saves the
segmentation.
"""
print('--------- Segmenting field of view:',fovindex,'Time point:',timeindex)
im = self.reader.LoadOneImage(timeindex, fovindex)
pred = self.LaunchPrediction(im, is_pc)
thresh = self.ThresholdPred(thr_val, pred)
seg = segment(thresh, pred, seg_val)
self.reader.SaveMask(timeindex, fovindex, seg)
print('--------- Finished segmenting.')
"""
This function is called in the LaunchBatchPrediction function.
This function calls the neural network function in the
InteractionDisk.py file and then thresholds the result
of the prediction, saves this thresholded prediction.
Then it segments the thresholded prediction and saves the
segmentation.
"""
print('--------- Segmenting field of view:',fovindex,'Time point:',timeindex)
im = self.reader.LoadOneImage(timeindex, fovindex)
try:
pred = self.LaunchPrediction(im, is_pc)
except ValueError:
QMessageBox.critical(self, 'Error',
'The neural network weight files could not '
'be found. Make sure to download them from '
'the link in the readme and put them into '
'the folder unet')
return
thresh = self.ThresholdPred(thr_val, pred)
seg = segment(thresh, pred, seg_val)
self.reader.SaveMask(timeindex, fovindex, seg)
print('--------- Finished segmenting.')
def LaunchPrediction(self, im, is_pc):
......@@ -731,7 +740,7 @@ class App(QMainWindow):
"""
im = skimage.exposure.equalize_adapthist(im)
im = im*1.0;
pred = nn.prediction(im, is_pc)
pred = nn.prediction(im, is_pc)
return pred
......
......@@ -59,9 +59,14 @@ def prediction(im, is_pc):
input_size = (None,None,1))
if is_pc:
model.load_weights('unet/unet_weights_batchsize_25_Nepochs_100_SJR0_10.hdf5')
path = 'unet/unet_weights_batchsize_25_Nepochs_100_SJR0_10.hdf5'
else:
model.load_weights('unet/unet_weights_BF_batchsize_10_Nepochs_100_SJR_0_1.hdf5')
path = 'unet/unet_weights_BF_batchsize_10_Nepochs_100_SJR_0_1.hdf5'
if not os.path.exists(path):
raise ValueError('Path does not exist')
model.load_weights(path)
results = model.predict(padded[np.newaxis,:,:,np.newaxis], batch_size=1)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment