From 33188eb825c2c9b32d2c35591a643e42dc9b8991 Mon Sep 17 00:00:00 2001 From: mattminder <myfiles@Mattus-MacBook-Pro.local> Date: Mon, 4 May 2020 14:38:29 +0200 Subject: [PATCH] Compatibility for TFv1 and TFv2 --- unet/model.py | 41 ++++++++++++++++++++++++++++------------- 1 file changed, 28 insertions(+), 13 deletions(-) diff --git a/unet/model.py b/unet/model.py index 717700f..0733466 100644 --- a/unet/model.py +++ b/unet/model.py @@ -6,23 +6,38 @@ import os import skimage.io as io import skimage.transform as trans import numpy as np -from tensorflow.keras.models import * -from tensorflow.keras.layers import * -from tensorflow.keras.optimizers import * -from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler -from tensorflow.keras import backend as keras - - -from tensorflow import ConfigProto -from tensorflow import InteractiveSession -#from tensorflow.compat.v1 import Session -#from tensorflow.compat.v1 import get_default_graph -#from tensorflow.compat.v1 import S -#import tensorflow + +# Import tensorflow differently depending on version +from tensorflow import __version__ as tf_version +tf_version_old = int(tf_version[0]) <= 1 +if tf_version_old: + from tensorflow.keras.models import * + from tensorflow.keras.layers import * + from tensorflow.keras.optimizers import * + from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler + from tensorflow.keras import backend as keras + from tensorflow import ConfigProto + from tensorflow import InteractiveSession + +else: + from tensorflow.keras.models import * + from tensorflow.keras.layers import * + from tensorflow.keras.optimizers import * + from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler + from tensorflow.keras import backend as keras + from tensorflow.compat.v1 import ConfigProto + from tensorflow.compat.v1 import InteractiveSession + + config = ConfigProto() config.gpu_options.allow_growth = True session = InteractiveSession(config=config) +#import tensorflow +#config = ConfigProto() +#config.gpu_options.allow_growth = True +#session = InteractiveSession(config=config) + #session_conf = ConfigProto(intra_op_parallelism_threads=8, inter_op_parallelism_threads=8, device_count = {'GPU':0}) #tensorflow.random.set_seed(1) #sess = Session(graph=get_default_graph(), config=session_conf) -- GitLab