From 46b3ec080fa2815eac50dfdabc8dd96d922508f9 Mon Sep 17 00:00:00 2001
From: alaskowski <alaskowski@ethz.ch>
Date: Fri, 30 Jun 2023 08:43:29 +0200
Subject: [PATCH] SSDM-13796: Fixes to fastdownload implementation.

---
 .../src/python/pybis/fast_download.py         | 160 +++++++++------
 .../src/python/tests/test_fastdownload.py     | 194 ++++++++++++++++--
 2 files changed, 282 insertions(+), 72 deletions(-)

diff --git a/api-openbis-python3-pybis/src/python/pybis/fast_download.py b/api-openbis-python3-pybis/src/python/pybis/fast_download.py
index b907bb91a3d..074c23d72a5 100644
--- a/api-openbis-python3-pybis/src/python/pybis/fast_download.py
+++ b/api-openbis-python3-pybis/src/python/pybis/fast_download.py
@@ -29,6 +29,7 @@ import binascii
 import functools
 import json
 import os
+import time
 from pathlib import Path
 from threading import Lock, Thread
 from urllib.parse import urljoin
@@ -115,6 +116,11 @@ def deserialize_chunk(byte_array):
         'invalid_reason': ""
     }
 
+    if len(byte_array) == 0:
+        result['invalid'] = True
+        result['invalid_reason'] = "HEADER"
+        return result
+
     start, end = 0, sequence_number_bytes
     result['sequence_number'] = int.from_bytes(byte_array[start:end], "big")
     start, end = end, end + download_item_id_length_bytes
@@ -173,6 +179,10 @@ class AtomicChecker:
         with self._lock:
             self._max += 1
 
+    def break_count(self):
+        with self._lock:
+            self._max = 0
+
     def remove_value(self, value):
         with self._lock:
             if value in self._set:
@@ -182,6 +192,13 @@ class AtomicChecker:
         return self._set
 
 
+def _get_json(response):
+    try:
+        return True, response.json()
+    except:
+        return False, response
+
+
 class DownloadThread(Thread):
     """Helper class defining single stream download"""
 
@@ -206,6 +223,7 @@ class DownloadThread(Thread):
                                                       downloadSessionId=self.download_session_id,
                                                       numberOfChunks=self.number_of_chunks,
                                                       downloadStreamId=self.stream_id)
+        retry_counter = 0
         while self.counter.should_continue():
             try:
                 download_response = self.session.post(self.download_url,
@@ -214,20 +232,36 @@ class DownloadThread(Thread):
                 if download_response.ok is True:
                     data = deserialize_chunk(download_response.content)
                     if data['invalid'] is True:
-                        print(f"Invalid checksum received. Retrying package")
-                        if data['invalid_reason'] == "PAYLOAD":
-                            sequence_number = data['sequence_number']
-                            if repeated_chunks.get(sequence_number, 0) >= DOWNLOAD_RETRIES_COUNT:
-                                raise ValueError(
-                                    "Received incorrect payload multiple times. Aborting.")
-                            repeated_chunks[sequence_number] = repeated_chunks.get(sequence_number,
-                                                                                   0) + 1
-                            queue_chunks(self.session, self.download_url,
-                                         self.download_session_id,
-                                         [f"{sequence_number}:{sequence_number}"],
-                                         self.verify_certificates)
-                            self.counter.repeat_call()  # queue additional download chunk run
+                        is_json, response = _get_json(download_response)
+                        if is_json:
+                            if 'retriable' in response and response['retriable'] is False:
+                                self.counter.break_count()
+                                raise ValueError(response["error"])
+                        else:
+                            if data['invalid_reason'] == "PAYLOAD":
+                                sequence_number = data['sequence_number']
+                                if repeated_chunks.get(sequence_number, 0) >= DOWNLOAD_RETRIES_COUNT:
+                                    self.counter.break_count()
+                                    raise ValueError(
+                                        "Received incorrect payload multiple times. Aborting.")
+                                repeated_chunks[sequence_number] = repeated_chunks.get(sequence_number,
+                                                                                       0) + 1
+                                queue_chunks(self.session, self.download_url,
+                                             self.download_session_id,
+                                             [f"{sequence_number}:{sequence_number}"],
+                                             self.verify_certificates)
+                                self.counter.repeat_call()  # queue additional download chunk run
+
+                        if retry_counter >= REQUEST_RETRIES_COUNT:
+                            self.counter.break_count()
+                            raise ValueError("Consecutive download calls to the server failed.")
+
+                        # Exponential backoff for the consecutive failures
+                        time.sleep(2 ** retry_counter)
+                        retry_counter += 1
+
                     else:
+                        retry_counter = 0
                         sequence_number = data['sequence_number']
                         self.save_to_file(data)
                         self.counter.remove_value(sequence_number)
@@ -322,35 +356,26 @@ class FastDownload:
                                               start_session_params)
         download_session_id = start_download_session['downloadSessionId']
 
-        try:
-            # Step 3 - Put files into fileserver download queue
 
-            ranges = start_download_session['ranges']
-            self._queue_all_files(download_url, download_session_id, ranges)
+        # Step 3 - Put files into fileserver download queue
 
-            # Step 4 - Download files in chunks
+        ranges = start_download_session['ranges']
+        self._queue_all_files(download_url, download_session_id, ranges)
 
-            session_stream_ids = list(start_download_session['streamIds'])
+        # Step 4 & 5 - Download files in chunks and close connection
 
-            exception_list = []
-            thread = Thread(target=self._download_step,
-                            args=(download_url, download_session_id, session_stream_ids, ranges,
-                                  exception_list))
-            thread.start()
+        session_stream_ids = list(start_download_session['streamIds'])
 
-            if self.wait_until_finished is True:
-                thread.join()
-                if exception_list:
-                    raise exception_list[0]
-        finally:
-            # Step 5 - Close the session
-            finish_download_session_params = make_fileserver_body_params(
-                method='finishDownloadSession',
-                downloadSessionId=download_session_id)
+        exception_list = []
+        thread = Thread(target=self._download_step,
+                        args=(download_url, download_session_id, session_stream_ids, ranges,
+                              exception_list))
+        thread.start()
 
-            self.session.post(download_url,
-                              data=json.dumps(finish_download_session_params),
-                              verify=self.verify_certificates)
+        if self.wait_until_finished is True:
+            thread.join()
+            if exception_list:
+                raise exception_list[0]
 
         return self.destination
 
@@ -408,28 +433,43 @@ class FastDownload:
         chunks_to_download = set(range(min_chunk, max_chunk + 1))
 
         counter = 1
-        while True:  # each iteration will create threads for streams
-            checker = AtomicChecker(chunks_to_download)
-            streams = [
-                DownloadThread(self.session, download_url, download_session_id, stream_id, checker,
-                               self.verify_certificates, self.create_default_folders,
-                               self.destination) for stream_id in session_stream_ids]
-
-            for thread in streams:
-                thread.start()
-            for thread in streams:
-                thread.join()
-
-            if chunks_to_download == set():  # if there are no more chunks to download
-                break
-            else:
-                if counter >= DOWNLOAD_RETRIES_COUNT:
-                    print(f"Reached maximum retry count:{counter}. Aborting.")
-                    exception_list += [
-                        ValueError(f"Reached maximum retry count:{counter}. Aborting.")]
+        try:
+            while True:  # each iteration will create threads for streams
+                checker = AtomicChecker(chunks_to_download)
+                streams = [
+                    DownloadThread(self.session, download_url, download_session_id, stream_id, checker,
+                                   self.verify_certificates, self.create_default_folders,
+                                   self.destination) for stream_id in session_stream_ids]
+
+                for thread in streams:
+                    thread.start()
+                for thread in streams:
+                    thread.join()
+
+                if chunks_to_download == set():  # if there are no more chunks to download
                     break
-                counter += 1
-                # queue chunks that we
-                queue_chunks(self.session, download_url, download_session_id,
-                             [f"{x}:{x}" for x in chunks_to_download],
-                             self.verify_certificates)
+                else:
+                    if counter >= DOWNLOAD_RETRIES_COUNT:
+                        print(f"Reached maximum retry count:{counter}. Aborting.")
+                        exception_list += [
+                            ValueError(f"Reached maximum retry count:{counter}. Aborting.")]
+                        break
+                    exceptions = [stream.exc for stream in streams if stream.exc is not None]
+                    if exceptions:
+                        print(f"Download failed with message: {exceptions[0]}")
+                        exception_list += exceptions
+                        break
+                    counter += 1
+                    # queue chunks that failed to download in the previous pass
+                    queue_chunks(self.session, download_url, download_session_id,
+                                 [f"{x}:{x}" for x in chunks_to_download],
+                                 self.verify_certificates)
+        finally:
+            # Step 5 - Close the session
+            finish_download_session_params = make_fileserver_body_params(
+                method='finishDownloadSession',
+                downloadSessionId=download_session_id)
+
+            self.session.post(download_url,
+                              data=json.dumps(finish_download_session_params),
+                              verify=self.verify_certificates)
diff --git a/api-openbis-python3-pybis/src/python/tests/test_fastdownload.py b/api-openbis-python3-pybis/src/python/tests/test_fastdownload.py
index 9ad2bbb63cf..0ada560b2e7 100644
--- a/api-openbis-python3-pybis/src/python/tests/test_fastdownload.py
+++ b/api-openbis-python3-pybis/src/python/tests/test_fastdownload.py
@@ -1,15 +1,17 @@
+import binascii
 import json
 import os
+import time
 from http.server import BaseHTTPRequestHandler, HTTPServer
 from threading import Thread
 
 import pytest
+
 from pybis.fast_download import FastDownload
 
 
 def get_download_response(sequence_number, perm_id, file, is_directory, offset, payload):
-    # binascii.crc32(byte_array[:end])
-    import binascii
+
     result = b''
     result += sequence_number.to_bytes(4, "big")
     download_item_id = perm_id + "/" + file
@@ -52,7 +54,7 @@ class MyServer(BaseHTTPRequestHandler):
         self.wfile.write(response)
 
 
-def createFastDownloadSession(permId, files, download_url, wished_number_of_streams):
+def create_fast_download_session(permId, files, download_url, wished_number_of_streams):
     return '''{ "jsonrpc": "2.0", "id": "2", "result": {
         "@type": "dss.dto.datasetfile.fastdownload.FastDownloadSession", "@id": 1,
         "downloadUrl": "''' + download_url + '''",
@@ -66,7 +68,7 @@ def createFastDownloadSession(permId, files, download_url, wished_number_of_stre
             "wishedNumberOfStreams": ''' + wished_number_of_streams + ''' } } }'''
 
 
-def startDownloadSession(ranges, wished_number_of_streams):
+def start_download_session(ranges, wished_number_of_streams):
     return """{
         "downloadSessionId": "72863f8d-1ed1-4795-a531-4d93a5081562",
         "ranges": {
@@ -113,18 +115,18 @@ def run_around_tests(base_data):
         'finishDownloadSession': "",
         'counter': 0,
         'parts': 10,
-        'createFastDownloadSession': createFastDownloadSession(perm_id,
-                                                               file,
-                                                               download_url,
-                                                               streams),
-        'startDownloadSession': startDownloadSession(ranges, streams)
+        'createFastDownloadSession': create_fast_download_session(perm_id,
+                                                                  file,
+                                                                  download_url,
+                                                                  streams),
+        'startDownloadSession': start_download_session(ranges, streams)
     }
     MyServer.response_code = 200
     yield temp_folder, download_url, streams, perm_id, file
     cleanup(temp_folder)
 
 
-def test_download_fails_after_retry(run_around_tests):
+def test_download_fails_after_retries(run_around_tests):
     temp_folder, download_url, streams, perm_id, file = run_around_tests
 
     def generate_download_response():
@@ -139,7 +141,7 @@ def test_download_fails_after_retry(run_around_tests):
         fast_download.download()
         assert False
     except ValueError as error:
-        assert str(error) == 'Reached maximum retry count:3. Aborting.'
+        assert str(error) == 'Consecutive download calls to the server failed.'
 
 
 def test_download_file(run_around_tests):
@@ -179,6 +181,58 @@ def test_download_file(run_around_tests):
         assert expected_outcome == data
 
 
+def test_download_file_wait_flag_disabled(run_around_tests):
+    temp_folder, download_url, streams, perm_id, file = run_around_tests
+
+    def generate_download_response():
+        parts = MyServer.next_response['parts']
+        counter = MyServer.next_response['counter']
+        payload_length = 10
+        while counter < parts:
+            response = get_download_response(counter, perm_id, file, False,
+                                             counter * payload_length,
+                                             bytearray([counter] * payload_length))
+            # Slow down responses to simulate download of a big file
+            time.sleep(0.1)
+            counter += 1
+            MyServer.next_response['counter'] = counter % parts
+            yield response
+
+    MyServer.next_response['download'] = generate_download_response()
+
+    fast_download = FastDownload("", download_url, perm_id, file, str(temp_folder),
+                                 True, False, False, streams)
+    fast_download.download()
+
+    # Verify that file has not been downloaded yet
+    downloaded_files = [
+        os.path.join(dp, f)
+        for dp, dn, fn in os.walk(temp_folder)
+        for f in fn
+    ]
+    assert len(downloaded_files) == 0
+
+    # Wait for 2 seconds to finish download
+    time.sleep(2)
+
+    # find file
+    downloaded_files = [
+        os.path.join(dp, f)
+        for dp, dn, fn in os.walk(temp_folder)
+        for f in fn
+    ]
+    assert len(downloaded_files) == 1
+
+    assert downloaded_files[0].endswith(file)
+    import functools
+    expected_outcome = functools.reduce(lambda a, b: a + b,
+                                        [bytearray([x] * 10) for x in range(10)])
+    with open(downloaded_files[0], 'rb') as fn:
+        data = fn.read()
+        assert len(data) == 100
+        assert expected_outcome == data
+
+
 def test_download_file_starts_with_fail(run_around_tests):
     temp_folder, download_url, streams, perm_id, file = run_around_tests
 
@@ -217,4 +271,120 @@ def test_download_file_starts_with_fail(run_around_tests):
     with open(downloaded_files[0], 'rb') as fn:
         data = fn.read()
         assert len(data) == 100
-        assert expected_outcome == data
\ No newline at end of file
+        assert expected_outcome == data
+
+
+def test_download_fails_after_getting_java_exception(run_around_tests):
+    """
+    Test that verifies that if a non-retriable exception is thrown,
+    the whole download session is aborted.
+    """
+    temp_folder, download_url, streams, perm_id, file = run_around_tests
+
+    def generate_download_response():
+        # First download fails with a non-retriable exception
+        yield b'{"error":"Some server error message.","retriable":false}'
+        # Further responses are alright
+        MyServer.response_code = 200
+        parts = MyServer.next_response['parts']
+        counter = MyServer.next_response['counter']
+        payload_length = 10
+        while counter < parts:
+            response = get_download_response(counter, perm_id, file, False,
+                                             counter * payload_length,
+                                             bytearray([counter] * payload_length))
+            counter += 1
+            MyServer.next_response['counter'] = counter % parts
+            yield response
+
+    MyServer.next_response['download'] = generate_download_response()
+
+    fast_download = FastDownload("", download_url, perm_id, file, str(temp_folder),
+                                 True, True, False, streams)
+    try:
+        fast_download.download()
+        assert False
+    except ValueError as error:
+        assert str(error) == 'Some server error message.'
+
+
+def test_download_passes_after_getting_java_exception(run_around_tests):
+    """
+    Test that verifies that if retryable server exception is thrown,
+    the whole download retries and downloads the file.
+    """
+    temp_folder, download_url, streams, perm_id, file = run_around_tests
+
+    def generate_download_response():
+        # First download fails with a retriable exception
+        yield b'{"error":"Some server error message.","retriable":true}'
+        # Further responses are alright
+        MyServer.response_code = 200
+        parts = MyServer.next_response['parts']
+        counter = MyServer.next_response['counter']
+        payload_length = 10
+        while counter < parts:
+            response = get_download_response(counter, perm_id, file, False,
+                                             counter * payload_length,
+                                             bytearray([counter] * payload_length))
+            counter += 1
+            MyServer.next_response['counter'] = counter % parts
+            yield response
+
+    MyServer.next_response['download'] = generate_download_response()
+
+    fast_download = FastDownload("", download_url, perm_id, file, str(temp_folder),
+                                 True, True, False, streams)
+    fast_download.download()
+
+    downloaded_files = [
+        os.path.join(dp, f)
+        for dp, dn, fn in os.walk(temp_folder)
+        for f in fn
+    ]
+    assert len(downloaded_files) == 1
+    assert downloaded_files[0].endswith(file)
+    import functools
+    expected_outcome = functools.reduce(lambda a, b: a + b,
+                                        [bytearray([x] * 10) for x in range(10)])
+    with open(downloaded_files[0], 'rb') as fn:
+        data = fn.read()
+        assert len(data) == 100
+        assert expected_outcome == data
+
+
+def test_download_file_payload_failure(run_around_tests):
+    temp_folder, download_url, streams, perm_id, file = run_around_tests
+
+    def generate_download_response():
+        parts = MyServer.next_response['parts']
+        counter = MyServer.next_response['counter']
+        payload_length = 10
+
+        fail_response = None
+        while counter < parts:
+            response = get_download_response(counter, perm_id, file, False,
+                                             counter * payload_length,
+                                             bytearray([counter] * payload_length))
+            if counter == 0:
+                array = bytearray(response)
+                array[-8:] = bytearray([0]*8)
+                response = bytes(array)
+                fail_response = response
+
+            counter += 1
+            MyServer.next_response['counter'] = counter % parts
+            yield response
+
+        while True:
+            yield fail_response
+
+    MyServer.next_response['download'] = generate_download_response()
+
+    fast_download = FastDownload("", download_url, perm_id, file, str(temp_folder),
+                                 True, True, False, streams)
+    try:
+        fast_download.download()
+        assert False
+    except ValueError as error:
+        assert str(error) == 'Received incorrect payload multiple times. Aborting.'
-- 
GitLab