From 46b3ec080fa2815eac50dfdabc8dd96d922508f9 Mon Sep 17 00:00:00 2001
From: alaskowski <alaskowski@ethz.ch>
Date: Fri, 30 Jun 2023 08:43:29 +0200
Subject: [PATCH] SSDM-13796: Fixes to fastdownload implementation.

---
 .../src/python/pybis/fast_download.py     | 160 +++++++++------
 .../src/python/tests/test_fastdownload.py | 194 ++++++++++++++++--
 2 files changed, 282 insertions(+), 72 deletions(-)

diff --git a/api-openbis-python3-pybis/src/python/pybis/fast_download.py b/api-openbis-python3-pybis/src/python/pybis/fast_download.py
index b907bb91a3d..074c23d72a5 100644
--- a/api-openbis-python3-pybis/src/python/pybis/fast_download.py
+++ b/api-openbis-python3-pybis/src/python/pybis/fast_download.py
@@ -29,6 +29,7 @@ import binascii
 import functools
 import json
 import os
+import time
 from pathlib import Path
 from threading import Lock, Thread
 from urllib.parse import urljoin
@@ -115,6 +116,11 @@ def deserialize_chunk(byte_array):
         'invalid_reason': ""
     }
 
+    if len(byte_array) == 0:
+        result['invalid'] = True
+        result['invalid_reason'] = "HEADER"
+        return result
+
     start, end = 0, sequence_number_bytes
     result['sequence_number'] = int.from_bytes(byte_array[start:end], "big")
     start, end = end, end + download_item_id_length_bytes
@@ -173,6 +179,10 @@ class AtomicChecker:
         with self._lock:
             self._max += 1
 
+    def break_count(self):
+        with self._lock:
+            self._max = 0
+
     def remove_value(self, value):
         with self._lock:
             if value in self._set:
@@ -182,6 +192,13 @@ class AtomicChecker:
         return self._set
 
 
+def _get_json(response):
+    try:
+        return True, response.json()
+    except Exception:
+        return False, response
+
+
 class DownloadThread(Thread):
     """Helper class defining single stream download"""
 
@@ -206,6 +223,7 @@ class DownloadThread(Thread):
             downloadSessionId=self.download_session_id,
             numberOfChunks=self.number_of_chunks,
             downloadStreamId=self.stream_id)
+        retry_counter = 0
         while self.counter.should_continue():
             try:
                 download_response = self.session.post(self.download_url,
@@ -214,20 +232,36 @@ class DownloadThread(Thread):
                 if download_response.ok is True:
                     data = deserialize_chunk(download_response.content)
                     if data['invalid'] is True:
-                        print(f"Invalid checksum received. Retrying package")
-                        if data['invalid_reason'] == "PAYLOAD":
-                            sequence_number = data['sequence_number']
-                            if repeated_chunks.get(sequence_number, 0) >= DOWNLOAD_RETRIES_COUNT:
-                                raise ValueError(
-                                    "Received incorrect payload multiple times. Aborting.")
-                            repeated_chunks[sequence_number] = repeated_chunks.get(sequence_number,
-                                                                                   0) + 1
-                            queue_chunks(self.session, self.download_url,
-                                         self.download_session_id,
-                                         [f"{sequence_number}:{sequence_number}"],
-                                         self.verify_certificates)
-                            self.counter.repeat_call()  # queue additional download chunk run
+                        is_json, response = _get_json(download_response)
+                        if is_json:
+                            if 'retriable' in response and response['retriable'] is False:
+                                self.counter.break_count()
+                                raise ValueError(response["error"])
+                        else:
+                            if data['invalid_reason'] == "PAYLOAD":
+                                sequence_number = data['sequence_number']
+                                if repeated_chunks.get(sequence_number, 0) >= DOWNLOAD_RETRIES_COUNT:
+                                    self.counter.break_count()
+                                    raise ValueError(
+                                        "Received incorrect payload multiple times. Aborting.")
+                                repeated_chunks[sequence_number] = repeated_chunks.get(sequence_number,
+                                                                                       0) + 1
+                                queue_chunks(self.session, self.download_url,
+                                             self.download_session_id,
+                                             [f"{sequence_number}:{sequence_number}"],
+                                             self.verify_certificates)
+                                self.counter.repeat_call()  # queue additional download chunk run
+
+                        if retry_counter >= REQUEST_RETRIES_COUNT:
+                            self.counter.break_count()
+                            raise ValueError("Consecutive download calls to the server failed.")
+
+                        # Exponential backoff between consecutive failures
+                        time.sleep(2 ** retry_counter)
+                        retry_counter += 1
+                    else:
+                        retry_counter = 0
                     sequence_number = data['sequence_number']
                     self.save_to_file(data)
                     self.counter.remove_value(sequence_number)
@@ -322,35 +356,26 @@ class FastDownload:
                                                    start_session_params)
         download_session_id = start_download_session['downloadSessionId']
 
-        try:
-            # Step 3 - Put files into fileserver download queue
-            ranges = start_download_session['ranges']
-            self._queue_all_files(download_url, download_session_id, ranges)
+        # Step 3 - Put files into fileserver download queue
 
-            # Step 4 - Download files in chunks
+        ranges = start_download_session['ranges']
+        self._queue_all_files(download_url, download_session_id, ranges)
 
-            session_stream_ids = list(start_download_session['streamIds'])
+        # Step 4 & 5 - Download files in chunks and close connection
 
-            exception_list = []
-            thread = Thread(target=self._download_step,
-                            args=(download_url, download_session_id, session_stream_ids, ranges,
-                                  exception_list))
-            thread.start()
+        session_stream_ids = list(start_download_session['streamIds'])
 
-            if self.wait_until_finished is True:
-                thread.join()
-                if exception_list:
-                    raise exception_list[0]
-        finally:
-            # Step 5 - Close the session
-            finish_download_session_params = make_fileserver_body_params(
-                method='finishDownloadSession',
-                downloadSessionId=download_session_id)
+        exception_list = []
+        thread = Thread(target=self._download_step,
+                        args=(download_url, download_session_id, session_stream_ids, ranges,
+                              exception_list))
+        thread.start()
 
-            self.session.post(download_url,
-                              data=json.dumps(finish_download_session_params),
-                              verify=self.verify_certificates)
+        if self.wait_until_finished is True:
+            thread.join()
+            if exception_list:
+                raise exception_list[0]
 
         return self.destination
 
@@ -408,28 +433,43 @@ class FastDownload:
         chunks_to_download = set(range(min_chunk, max_chunk + 1))
 
         counter = 1
-        while True:  # each iteration will create threads for streams
-            checker = AtomicChecker(chunks_to_download)
-            streams = [
-                DownloadThread(self.session, download_url, download_session_id, stream_id, checker,
-                               self.verify_certificates, self.create_default_folders,
-                               self.destination) for stream_id in session_stream_ids]
-
-            for thread in streams:
-                thread.start()
-            for thread in streams:
-                thread.join()
-
-            if chunks_to_download == set():  # if there are no more chunks to download
-                break
-            else:
-                if counter >= DOWNLOAD_RETRIES_COUNT:
-                    print(f"Reached maximum retry count:{counter}. Aborting.")
-                    exception_list += [
-                        ValueError(f"Reached maximum retry count:{counter}. Aborting.")]
+        try:
+            while True:  # each iteration will create threads for streams
+                checker = AtomicChecker(chunks_to_download)
+                streams = [
+                    DownloadThread(self.session, download_url, download_session_id, stream_id, checker,
+                                   self.verify_certificates, self.create_default_folders,
+                                   self.destination) for stream_id in session_stream_ids]
+
+                for thread in streams:
+                    thread.start()
+                for thread in streams:
+                    thread.join()
+
+                if chunks_to_download == set():  # if there are no more chunks to download
                     break
-                counter += 1
-                # queue chunks that we
-                queue_chunks(self.session, download_url, download_session_id,
-                             [f"{x}:{x}" for x in chunks_to_download],
-                             self.verify_certificates)
+                else:
+                    if counter >= DOWNLOAD_RETRIES_COUNT:
+                        print(f"Reached maximum retry count:{counter}. Aborting.")
+                        exception_list += [
+                            ValueError(f"Reached maximum retry count:{counter}. Aborting.")]
+                        break
+                    exceptions = [stream.exc for stream in streams if stream.exc is not None]
+                    if exceptions:
+                        print(f"Download failed with message: {exceptions[0]}")
+                        exception_list += exceptions
+                        break
+                    counter += 1
+                    # queue chunks that failed to download in the previous pass
+                    queue_chunks(self.session, download_url, download_session_id,
+                                 [f"{x}:{x}" for x in chunks_to_download],
+                                 self.verify_certificates)
+        finally:
+            # Step 5 - Close the session
+            finish_download_session_params = make_fileserver_body_params(
+                method='finishDownloadSession',
+                downloadSessionId=download_session_id)
+
+            self.session.post(download_url,
+                              data=json.dumps(finish_download_session_params),
+                              verify=self.verify_certificates)
diff --git a/api-openbis-python3-pybis/src/python/tests/test_fastdownload.py b/api-openbis-python3-pybis/src/python/tests/test_fastdownload.py
index 9ad2bbb63cf..0ada560b2e7 100644
--- a/api-openbis-python3-pybis/src/python/tests/test_fastdownload.py
+++ b/api-openbis-python3-pybis/src/python/tests/test_fastdownload.py
@@ -1,15 +1,17 @@
+import binascii
 import json
 import os
+import time
 from http.server import BaseHTTPRequestHandler, HTTPServer
 from threading import Thread
 
 import pytest
+
 from pybis.fast_download import FastDownload
 
 
 def get_download_response(sequence_number, perm_id, file, is_directory, offset, payload):
-    # binascii.crc32(byte_array[:end])
-    import binascii
+
     result = b''
     result += sequence_number.to_bytes(4, "big")
     download_item_id = perm_id + "/" + file
@@ -52,7 +54,7 @@ class MyServer(BaseHTTPRequestHandler):
         self.wfile.write(response)
 
 
-def createFastDownloadSession(permId, files, download_url, wished_number_of_streams):
+def create_fast_download_session(permId, files, download_url, wished_number_of_streams):
     return '''{ "jsonrpc": "2.0", "id": "2", "result": {
         "@type": "dss.dto.datasetfile.fastdownload.FastDownloadSession", "@id": 1,
         "downloadUrl": "''' + download_url + '''",
         "fileTransferUserSessionId": "admin-230627093731025xA3DFC4FC8A085FA9",
         "files": ["''' + permId + "/" + files + '''"],
         "permId": { "@type": "as.dto.dataset.id.DataSetPermId", "@id": 2,
             "permId": "''' + permId + '''" },
         "wishedNumberOfStreams": ''' + wished_number_of_streams + ''' } } }'''
 
 
-def startDownloadSession(ranges, wished_number_of_streams):
+def start_download_session(ranges, wished_number_of_streams):
     return """{
         "downloadSessionId": "72863f8d-1ed1-4795-a531-4d93a5081562",
         "ranges": {
@@ -113,18 +115,18 @@ def run_around_tests(base_data):
         'finishDownloadSession': "",
         'counter': 0,
         'parts': 10,
-        'createFastDownloadSession': createFastDownloadSession(perm_id,
-                                                               file,
-                                                               download_url,
-                                                               streams),
-        'startDownloadSession': startDownloadSession(ranges, streams)
+        'createFastDownloadSession': create_fast_download_session(perm_id,
+                                                                  file,
+                                                                  download_url,
+                                                                  streams),
+        'startDownloadSession': start_download_session(ranges, streams)
     }
     MyServer.response_code = 200
     yield temp_folder, download_url, streams, perm_id, file
     cleanup(temp_folder)
 
 
-def test_download_fails_after_retry(run_around_tests):
+def test_download_fails_after_retries(run_around_tests):
     temp_folder, download_url, streams, perm_id, file = run_around_tests
 
     def generate_download_response():
@@ -139,7 +141,7 @@ def test_download_fails_after_retry(run_around_tests):
         fast_download.download()
         assert False
     except ValueError as error:
-        assert str(error) == 'Reached maximum retry count:3. Aborting.'
+        assert str(error) == 'Consecutive download calls to the server failed.'
 
 
 def test_download_file(run_around_tests):
@@ -179,6 +181,58 @@
     assert expected_outcome == data
 
 
+def test_download_file_wait_flag_disabled(run_around_tests):
+    temp_folder, download_url, streams, perm_id, file = run_around_tests
+
+    def generate_download_response():
+        parts = MyServer.next_response['parts']
+        counter = MyServer.next_response['counter']
+        payload_length = 10
+        while counter < parts:
+            response = get_download_response(counter, perm_id, file, False,
+                                             counter * payload_length,
+                                             bytearray([counter] * payload_length))
+            # Slow down responses to simulate the download of a big file
+            time.sleep(0.1)
+            counter += 1
+            MyServer.next_response['counter'] = counter % parts
+            yield response
+
+    MyServer.next_response['download'] = generate_download_response()
+
+    fast_download = FastDownload("", download_url, perm_id, file, str(temp_folder),
+                                 True, False, False, streams)
+    fast_download.download()
+
+    # Verify that the file has not been downloaded yet
+    downloaded_files = [
+        os.path.join(dp, f)
+        for dp, dn, fn in os.walk(temp_folder)
+        for f in fn
+    ]
+    assert len(downloaded_files) == 0
+
+    # Wait 2 seconds for the download to finish
+    time.sleep(2)
+
+    # Find the downloaded file
+    downloaded_files = [
+        os.path.join(dp, f)
+        for dp, dn, fn in os.walk(temp_folder)
+        for f in fn
+    ]
+    assert len(downloaded_files) == 1
+
+    assert downloaded_files[0].endswith(file)
+    import functools
+    expected_outcome = functools.reduce(lambda a, b: a + b,
+                                        [bytearray([x] * 10) for x in range(10)])
+    with open(downloaded_files[0], 'rb') as fn:
+        data = fn.read()
+        assert len(data) == 100
+        assert expected_outcome == data
+
+
 def test_download_file_starts_with_fail(run_around_tests):
     temp_folder, download_url, streams, perm_id, file = run_around_tests
 
     def generate_download_response():
@@ -217,4 +271,120 @@ def test_download_file_starts_with_fail(run_around_tests):
     with open(downloaded_files[0], 'rb') as fn:
         data = fn.read()
         assert len(data) == 100
-    assert expected_outcome == data
\ No newline at end of file
+    assert expected_outcome == data
+
+
+def test_download_fails_after_getting_java_exception(run_around_tests):
+    """
+    Test that verifies that if a non-retryable exception is thrown,
+    the whole download session is aborted.
+    """
+    temp_folder, download_url, streams, perm_id, file = run_around_tests
+
+    def generate_download_response():
+        # The first download call fails with a non-retryable exception
+        yield b'{"error":"Some server error message.","retriable":false}'
+        # Subsequent responses are fine
+        MyServer.response_code = 200
+        parts = MyServer.next_response['parts']
+        counter = MyServer.next_response['counter']
+        payload_length = 10
+        while counter < parts:
+            response = get_download_response(counter, perm_id, file, False,
+                                             counter * payload_length,
+                                             bytearray([counter] * payload_length))
+            counter += 1
+            MyServer.next_response['counter'] = counter % parts
+            yield response
+
+    MyServer.next_response['download'] = generate_download_response()
+
+    fast_download = FastDownload("", download_url, perm_id, file, str(temp_folder),
+                                 True, True, False, streams)
+    try:
+        fast_download.download()
+        assert False
+    except ValueError as error:
+        assert str(error) == 'Some server error message.'
+
+
+def test_download_passes_after_getting_java_exception(run_around_tests):
+    """
+    Test that verifies that if a retryable server exception is thrown,
+    the download is retried and the file is downloaded.
+    """
+    temp_folder, download_url, streams, perm_id, file = run_around_tests
+
+    def generate_download_response():
+        # The first download call fails with a retryable exception
+        yield b'{"error":"Some server error message.","retriable":true}'
+        # Subsequent responses are fine
+        MyServer.response_code = 200
+        parts = MyServer.next_response['parts']
+        counter = MyServer.next_response['counter']
+        payload_length = 10
+        while counter < parts:
+            response = get_download_response(counter, perm_id, file, False,
+                                             counter * payload_length,
+                                             bytearray([counter] * payload_length))
+            counter += 1
+            MyServer.next_response['counter'] = counter % parts
+            yield response
+
+    MyServer.next_response['download'] = generate_download_response()
+
+    fast_download = FastDownload("", download_url, perm_id, file, str(temp_folder),
+                                 True, True, False, streams)
+    fast_download.download()
+
+    downloaded_files = [
+        os.path.join(dp, f)
+        for dp, dn, fn in os.walk(temp_folder)
+        for f in fn
+    ]
+    assert len(downloaded_files) == 1
+    assert downloaded_files[0].endswith(file)
+    import functools
+    expected_outcome = functools.reduce(lambda a, b: a + b,
+                                        [bytearray([x] * 10) for x in range(10)])
+    with open(downloaded_files[0], 'rb') as fn:
+        data = fn.read()
+        assert len(data) == 100
+        assert expected_outcome == data
+
+
+def test_download_file_payload_failure(run_around_tests):
+    temp_folder, download_url, streams, perm_id, file = run_around_tests
+
+    def generate_download_response():
+        parts = MyServer.next_response['parts']
+        counter = MyServer.next_response['counter']
+        payload_length = 10
+
+        fail_response = None
+        while counter < parts:
+            response = get_download_response(counter, perm_id, file, False,
+                                             counter * payload_length,
+                                             bytearray([counter] * payload_length))
+            if counter == 0:
+                array = bytearray(response)
+                array[-8:] = bytearray([0] * 8)  # corrupt the trailing checksum bytes
+                response = bytes(array)
+                fail_response = response
+
+            counter += 1
+            MyServer.next_response['counter'] = counter % parts
+            yield response
+
+        while True:
+            yield fail_response
+
+    MyServer.next_response['download'] = generate_download_response()
+
+    fast_download = FastDownload("", download_url, perm_id, file, str(temp_folder),
+                                 True, True, False, streams)
+    try:
+        fast_download.download()
+        assert False
+    except ValueError as error:
+        assert str(error) == 'Received incorrect payload multiple times. Aborting.'
-- 
GitLab
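
Note on the retry change above: DownloadThread.run now counts consecutive failed
download calls, sleeps 2**n seconds between attempts, stops the whole stream via
break_count() once the budget is exhausted, and resets the counter whenever a
valid chunk arrives. A minimal standalone sketch of that pattern, not part of
the patch; the constant value 3 and the helper name call_with_backoff are
illustrative assumptions:

    import time

    REQUEST_RETRIES_COUNT = 3  # assumed value for the constant the patch references


    def call_with_backoff(request_once):
        # Mirrors the loop in DownloadThread.run: a successful call ends the
        # failure streak; each consecutive failure sleeps 2**n seconds
        # (1, 2, 4, ...) before the next attempt, and once retry_counter
        # reaches the budget the download is aborted with the same message
        # the patched code uses.
        retry_counter = 0
        while True:
            result = request_once()
            if result is not None:
                return result  # success resets the backoff cycle
            if retry_counter >= REQUEST_RETRIES_COUNT:
                raise ValueError("Consecutive download calls to the server failed.")
            time.sleep(2 ** retry_counter)  # exponential backoff
            retry_counter += 1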
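
For reference, a minimal sketch of how the patched FastDownload is driven,
modelled on the FastDownload(...) calls in the tests above. The meaning of the
positional arguments is inferred from those calls and should be checked against
the pybis sources; every concrete value (URL, perm id, file path, destination,
stream count) is hypothetical:

    from pybis.fast_download import FastDownload

    fast_download = FastDownload(
        "",                                  # session token (empty in the tests)
        "http://localhost:8085/fileserver",  # fileserver download URL (hypothetical)
        "20230630100000000-42",              # data set perm id (hypothetical)
        "original/file1.txt",                # file to download (hypothetical)
        "./downloads",                       # destination folder
        True,                                # create default folder structure
        True,                                # wait_until_finished: block until done
        False,                               # verify certificates
        "1",                                 # wished number of streams, as in the tests
    )
    # With the wait flag set, download() joins the worker thread, re-raises the
    # first collected downloader error (a ValueError for a non-retriable server
    # error or an exhausted retry budget), and returns the destination folder.
    destination = fast_download.download()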