From 33188eb825c2c9b32d2c35591a643e42dc9b8991 Mon Sep 17 00:00:00 2001
From: mattminder <myfiles@Mattus-MacBook-Pro.local>
Date: Mon, 4 May 2020 14:38:29 +0200
Subject: [PATCH] Compatibility for TFv1 and TFv2

---
 unet/model.py | 41 ++++++++++++++++++++++++++++-------------
 1 file changed, 28 insertions(+), 13 deletions(-)

diff --git a/unet/model.py b/unet/model.py
index 717700f..0733466 100644
--- a/unet/model.py
+++ b/unet/model.py
@@ -6,23 +6,38 @@ import os
 import skimage.io as io
 import skimage.transform as trans
 import numpy as np
-from tensorflow.keras.models import *
-from tensorflow.keras.layers import *
-from tensorflow.keras.optimizers import *
-from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler
-from tensorflow.keras import backend as keras
-
-
-from tensorflow import ConfigProto
-from tensorflow import InteractiveSession
-#from tensorflow.compat.v1 import Session
-#from tensorflow.compat.v1 import get_default_graph
-#from tensorflow.compat.v1 import S
-#import tensorflow
+
+# Import tensorflow differently depending on version
+from tensorflow import __version__ as tf_version
+tf_version_old = int(tf_version[0]) <= 1
+if tf_version_old:
+    from tensorflow.keras.models import *
+    from tensorflow.keras.layers import *
+    from tensorflow.keras.optimizers import *
+    from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler
+    from tensorflow.keras import backend as keras
+    from tensorflow import ConfigProto
+    from tensorflow import InteractiveSession
+
+else:
+    from tensorflow.keras.models import *
+    from tensorflow.keras.layers import *
+    from tensorflow.keras.optimizers import *
+    from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler
+    from tensorflow.keras import backend as keras
+    from tensorflow.compat.v1 import ConfigProto
+    from tensorflow.compat.v1 import InteractiveSession
+    
+    
 config = ConfigProto()
 config.gpu_options.allow_growth = True
 session = InteractiveSession(config=config)
 
+#import tensorflow
+#config = ConfigProto()
+#config.gpu_options.allow_growth = True
+#session = InteractiveSession(config=config)
+
 #session_conf = ConfigProto(intra_op_parallelism_threads=8, inter_op_parallelism_threads=8, device_count = {'GPU':0})
 #tensorflow.random.set_seed(1)
 #sess = Session(graph=get_default_graph(), config=session_conf)
-- 
GitLab