From d53e18b0fd6f0b71421d94d2c17f485b9227826f Mon Sep 17 00:00:00 2001
From: vermeul <swen@ethz.ch>
Date: Fri, 26 Aug 2022 17:05:43 +0200
Subject: [PATCH] feat: look in all parent dirs for local configs,
 get_token_for_hostname, automatic use of existing pats

---
 pybis/src/python/pybis/pybis.py | 130 ++++++++++++++++++++++++++++----
 1 file changed, 115 insertions(+), 15 deletions(-)

diff --git a/pybis/src/python/pybis/pybis.py b/pybis/src/python/pybis/pybis.py
index dec7d7db631..feef7b6ba1d 100644
--- a/pybis/src/python/pybis/pybis.py
+++ b/pybis/src/python/pybis/pybis.py
@@ -166,7 +166,7 @@ def read_config(config_filepath: Path) -> dict:
     if config_filepath.exists():
         with open(config_filepath, "r") as fh:
             config = json.load(fh)
-        return config
+    return config
 
 
 def get_global_config():
@@ -183,8 +183,12 @@ def set_global_config(hostname=None, token=None):
 
 
 def get_local_config():
-    config_filepath = Path.cwd() / CONFIG_FILENAME
-    config = read_config(config_filepath=config_filepath)
+    config = {}
+    for path in [Path.cwd(), *Path.cwd().parents]:
+        config_filepath = path / CONFIG_FILENAME
+        if config_filepath.exists() and config_filepath.is_file():
+            config = read_config(config_filepath=config_filepath)
+            break
     return config
 
 
@@ -196,14 +200,38 @@ def set_local_config(hostname=None, token=None):
 
 
 def get_saved_tokens():
-    tokens = defaultdict(list)
+    tokens = {}
     for filepath in PYBIS_FOLDER.glob("*.token"):
         with open(filepath) as fh:
             if filepath.is_file:
-                tokens[filepath.stem].append(fh.read())
+                token = fh.read()
+                tokens[filepath.stem] = token
     return tokens
 
 
+def get_token_for_hostname(hostname, session_token_needed=True):
+    """Searches for a stored token for a given host in this order:
+    cwd/.pybis.json
+    ~/.pybis/.pybis.json
+    ~/.pybis/hostname.token
+    """
+    for config in [get_local_config(), get_global_config()]:
+        if config.get("hostname") == hostname:
+            if session_token_needed:
+                if is_session_token(config.get("token")):
+                    return config.get("token")
+            else:
+                return config.get("token")
+    tokens = get_saved_tokens()
+    if hostname in tokens:
+        if session_token_needed:
+            if is_session_token(tokens[hostname]):
+                return tokens[hostname]
+        else:
+            return tokens[hostname]
+    return
+
+
 def save_pats_to_disk(hostname: str, resp: dict) -> None:
     pats = resp["objects"]
     parse_jackson(pats)
@@ -1001,6 +1029,17 @@ class Openbis:
         if not verify_certificates:
             urllib3.disable_warnings()
 
+        config_local = {}
+        config_global = {}
+        if url is None:
+            config_local = get_local_config()
+            if config_local:
+                url = config_local.get("hostname")
+            else:
+                config_global = get_global_config()
+                if config_global:
+                    url = config_global.get("hostname")
+
         if url is None:
             url = os.environ.get("OPENBIS_URL") or os.environ.get("OPENBIS_HOST")
             if url is None:
@@ -1038,12 +1077,23 @@ class Openbis:
             except:
                 pass
         else:
-            try:
-                token = self._get_saved_token()
-                self.token = token
-            except ValueError:
-                self._delete_saved_token()
-                pass
+            while True:
+                try:
+                    self.token = self._get_saved_token()
+                    break
+                except ValueError:
+                    self._delete_saved_token()
+                    pass
+                try:
+                    self.token = config_local.get("token")
+                    break
+                except ValueError:
+                    pass
+                try:
+                    self.token = config_global.get("token")
+                    break
+                except ValueError:
+                    pass
 
     def _get_username(self):
         if self.token:
@@ -1148,7 +1198,8 @@ class Openbis:
             "new_material_type()",
             "new_semantic_annotation()",
             "new_transaction()",
-            "create_personal_access_token()",
+            "new_personal_access_token()",
+            "renew_token()",
             "set_token()",
         ]
 
@@ -1989,7 +2040,7 @@ class Openbis:
             df_initializer=create_data_frame,
         )
 
-    def create_personal_access_token(
+    def new_personal_access_token(
         self, sessionName: str, validFrom: datetime = None, validTo: datetime = None
     ) -> str:
         """Creates a new personal access token (PAT)"""
@@ -1999,6 +2050,11 @@ class Openbis:
         if validTo is None:
             validTo = datetime.now() + relativedelta(years=1)
 
+        if is_personal_access_token(self.token):
+            raise ValueError(
+                "You you need a session token to create a new personal access token."
+            )
+
         entity = "personalAccessToken"
         request = {
             "method": get_method_for_entity(entity, "create"),
@@ -2025,6 +2081,46 @@ class Openbis:
             pass
             # if "error" in resp and resp["error"]["message"] == "method not found":
 
+    def renew_token(
+        self, username=None, password=None, hostname=None, save_token=False, token=None
+    ):
+        if token is None:
+            token = self.token
+
+        if hostname is None:
+            hostname = self.hostname
+
+        if is_session_token(token):
+            if self.is_token_valid(token):
+                # no need to renew a session token as it renews itself
+                return
+            else:
+                self.login(username=username, password=password, save_token=save_token)
+                return
+
+        session_token = get_token_for_hostname(hostname, session_token_needed=True)
+        try:
+            self.set_token(session_token)
+        except Exception:
+            self.login(username=username, password=password)
+
+        import pdb
+
+        pdb.set_trace()
+        session_info = self.get_session_info(token=token)
+        validFrom_orig = datetime.strptime(session_info.validFrom, "%Y-%m-%d %H:%H:%S")
+        validTo_orig = datetime.strptime(session_info.validTo, "%Y-%m-%d %H:%H:%S")
+        days_delta = abs(validFrom_orig - validTo_orig).days
+
+        new_pat = self.new_personal_access_token(
+            sessionName=session_info.sessionName,
+            validFrom=datetime.now(),
+            validTo=relativedelta(datetime.now(), days=days_delta),
+        )
+        self.set_token(new_pat)
+        if VERBOSE:
+            print(self.token)
+
     def get_personal_access_tokens(
         self,
         permId=None,
@@ -2067,7 +2163,7 @@ class Openbis:
             attrs = defs["attrs"]
             objects = response["objects"]
             if len(objects) == 0:
-                persons = DataFrame(columns=attrs)
+                pats = DataFrame(columns=attrs)
             else:
                 parse_jackson(objects)
 
@@ -4331,7 +4427,7 @@ class Openbis:
             return None
         return SessionInformation(openbis_obj=self, data=resp)
 
-    def set_token(self, token, save_token=False):
+    def set_token(self, token, save_token=False, save_local=False, save_global=False):
         """Checks the validity of a token, sets it as the current token and (by default) saves it
         to the disk, i.e. in the ~/.pybis directory
         """
@@ -4340,6 +4436,10 @@ class Openbis:
         if not self.is_token_valid(token):
             raise ValueError("Session is no longer valid. Please log in again.")
         else:
+            if self.token and is_session_token(self.token):
+                # if current token is a session token, save it in .session_token
+                # just in case we need it later
+                self.session_token = self.token
             self.__dict__["token"] = token
         if save_token:
             self._save_token_to_disk()
-- 
GitLab