From 6ae27f3046150d72c1f68d338d63bfb0d1b9aea1 Mon Sep 17 00:00:00 2001
From: Swen Vermeul <swen@ethz.ch>
Date: Fri, 10 Jun 2016 12:34:43 +0200
Subject: [PATCH] parallel download of files. Number of workers can be
 specified; optionally do not wait until the download is complete

---
 src/python/PyBis/pybis/pybis.py | 74 ++++++++++++++++++++++++++-------
 1 file changed, 59 insertions(+), 15 deletions(-)

diff --git a/src/python/PyBis/pybis/pybis.py b/src/python/PyBis/pybis/pybis.py
index ee2f0ee1a0e..9ae6ad17069 100644
--- a/src/python/PyBis/pybis/pybis.py
+++ b/src/python/PyBis/pybis/pybis.py
@@ -16,6 +16,10 @@ import json
 import re
 from urllib.parse import urlparse
 
+import threading
+from threading import Thread
+from queue import Queue
+
 
 class OpenbisCredentials:
     """Credentials for communicating with openBIS."""
@@ -305,19 +309,6 @@
             for sample_ident in resp:
                 return Sample(self, sample_ident, resp[sample_ident])
 
-    @staticmethod
-    def download_file(url, filename):
-        
-        # create the necessary directory structure if they don't exist yet
-        os.makedirs(os.path.dirname(filename), exist_ok=True)
-
-        # request the file in streaming mode
-        r = requests.get(url, stream=True)
-        with open(filename, 'wb') as f:
-            for chunk in r.iter_content(chunk_size=1024): 
-                if chunk: # filter out keep-alive new chunks
-                    f.write(chunk)
-        return filename
-
 
     def get_samples_with_data(self, sample_identifiers):
         """Retrieve metadata for the sample, like get_sample_metadata, but retrieve any data sets as well,
@@ -349,6 +341,54 @@
         # TODO Implement the logic of this method
 
 
+class DataSetDownloadQueue:
+    """ Simple thread pool: a Queue plus daemon worker threads which download
+    [url, filename] pairs in parallel.
+    """
+
+    def __init__(self, workers=20):
+        # holds [url, filename] pairs waiting to be downloaded
+        self.download_queue = Queue()
+
+        # spawn the daemon worker threads; they die with the main program
+        for i in range(workers):
+            t = Thread(target=self.download_file)
+            t.daemon = True
+            t.start()
+
+
+    def put(self, things):
+        """ expects a list [url, filename] which is put into the download queue
+        """
+        self.download_queue.put(things)
+
+    def join(self):
+        """ needs to be called if you want to wait for all downloads to be finished
+        """
+        self.download_queue.join()
+
+
+    def download_file(self):
+        """ worker loop: take [url, filename] pairs off the queue and stream
+        each file to disk. task_done() is called even when a download raises,
+        so join() cannot block forever on a failed file.
+        """
+        while True:
+            url, filename = self.download_queue.get()
+            try:
+                # create the necessary directory structure if they don't exist yet
+                os.makedirs(os.path.dirname(filename), exist_ok=True)
+
+                # request the file in streaming mode
+                r = requests.get(url, stream=True)
+                with open(filename, 'wb') as f:
+                    for chunk in r.iter_content(chunk_size=1024):
+                        if chunk: # filter out keep-alive new chunks
+                            f.write(chunk)
+            finally:
+                self.download_queue.task_done()
+
+
 class DataSet(Openbis):
     """objects which contain datasets"""
 
@@ -359,24 +391,41 @@
         self.v1_ds = '/datastore_server/rmi-dss-api-v1.json'
         self.downloadUrl = self.data['dataStore']['downloadUrl']
 
+
     @staticmethod
     def ensure_folder_exists(folder): 
         if not os.path.exists(folder):
            os.makedirs(folder)
 
 
-    def download(self):
+    def download(self, wait_until_finished=False, workers=10):
+        """ download the actual files and put them in the following folder:
+        __current_dir__/hostname/dataset_permid/
+        If wait_until_finished is False, the returned queue object lets the
+        caller invoke its join() method later to wait for completion.
+        """
+
         base_url = self.downloadUrl + '/datastore_server/' + self.permid + '/'
 
+        queue = DataSetDownloadQueue(workers=workers)
+
+        # get file list and start download
         for file in self.get_file_list(recursive=True):
             if file['isDirectory']:
-
                 folder = os.path.join(self.openbis.hostname, self.permid)
                 DataSet.ensure_folder_exists(folder)
             else:
                 download_url = base_url + file['pathInDataSet'] + '?sessionID=' + self.openbis.token 
                 filename = os.path.join(self.openbis.hostname, self.permid, file['pathInDataSet'])
-                DataSet.download_file(download_url, filename)
+                queue.put([download_url, filename])
+
+        # wait until all files have downloaded
+        if wait_until_finished:
+            queue.join()
+
+        # return the queue so callers can wait for completion themselves
+        return queue
+
 
     def get_parents(self):
         parents = []
-- 
GitLab