... |
... |
@@ -95,6 +95,9 @@ class CASRemote(): |
95
|
95
|
|
96
|
96
|
self.__tmp_downloads = [] # files in the tmpdir waiting to be added to local caches
|
97
|
97
|
|
|
98
|
+ self.__batch_read = None
|
|
99
|
+ self.__batch_update = None
|
|
100
|
+
|
98
|
101
|
def init(self):
|
99
|
102
|
if not self._initialized:
|
100
|
103
|
url = urlparse(self.spec.url)
|
... |
... |
@@ -152,6 +155,7 @@ class CASRemote(): |
152
|
155
|
request = remote_execution_pb2.BatchReadBlobsRequest()
|
153
|
156
|
response = self.cas.BatchReadBlobs(request)
|
154
|
157
|
self.batch_read_supported = True
|
|
158
|
+ self.__batch_read = _CASBatchRead(self)
|
155
|
159
|
except grpc.RpcError as e:
|
156
|
160
|
if e.code() != grpc.StatusCode.UNIMPLEMENTED:
|
157
|
161
|
raise
|
... |
... |
@@ -162,6 +166,7 @@ class CASRemote(): |
162
|
166
|
request = remote_execution_pb2.BatchUpdateBlobsRequest()
|
163
|
167
|
response = self.cas.BatchUpdateBlobs(request)
|
164
|
168
|
self.batch_update_supported = True
|
|
169
|
+ self.__batch_update = _CASBatchUpdate(self)
|
165
|
170
|
except grpc.RpcError as e:
|
166
|
171
|
if (e.code() != grpc.StatusCode.UNIMPLEMENTED and
|
167
|
172
|
e.code() != grpc.StatusCode.PERMISSION_DENIED):
|
... |
... |
@@ -298,18 +303,21 @@ class CASRemote(): |
298
|
303
|
|
299
|
304
|
# request_blob():
|
300
|
305
|
#
|
301
|
|
- # Request blob and returns path to tmpdir location
|
|
306
|
+ # Request blob, triggering download depending via bytestream or cas
|
|
307
|
+ # BatchReadBlobs depending on size.
|
302
|
308
|
#
|
303
|
309
|
# Args:
|
304
|
310
|
# digest (Digest): digest of the requested blob
|
305
|
|
- # path (str): tmpdir locations of downloaded blobs
|
306
|
311
|
#
|
307
|
312
|
def request_blob(self, digest):
|
308
|
|
- # TODO expand for adding to batches some other logic
|
309
|
|
- f = tempfile.NamedTemporaryFile(dir=self.tmpdir)
|
310
|
|
- self._fetch_blob(digest, f)
|
311
|
|
- self.__tmp_downloads.append(f)
|
312
|
|
- return f.name
|
|
313
|
+ if (not self.batch_read_supported or
|
|
314
|
+ digest.size_bytes > self.max_batch_total_size_bytes):
|
|
315
|
+ f = tempfile.NamedTemporaryFile(dir=self.tmpdir)
|
|
316
|
+ self._fetch_blob(digest, f)
|
|
317
|
+ self.__tmp_downloads.append(f)
|
|
318
|
+ elif self.__batch_read.add(digest) is False:
|
|
319
|
+ self._download_batch()
|
|
320
|
+ self.__batch_read.add(digest)
|
313
|
321
|
|
314
|
322
|
# get_blobs():
|
315
|
323
|
#
|
... |
... |
@@ -318,7 +326,12 @@ class CASRemote(): |
318
|
326
|
#
|
319
|
327
|
# Returns:
|
320
|
328
|
# iterator over NamedTemporaryFile
|
321
|
|
- def get_blobs(self):
|
|
329
|
+ def get_blobs(self, request_batch=False):
|
|
330
|
+ # Send read batch request and download
|
|
331
|
+ if (request_batch is True and
|
|
332
|
+ self.batch_read_supported is True):
|
|
333
|
+ self._download_batch()
|
|
334
|
+
|
322
|
335
|
while self.__tmp_downloads:
|
323
|
336
|
yield self.__tmp_downloads.pop()
|
324
|
337
|
|
... |
... |
@@ -349,18 +362,19 @@ class CASRemote(): |
349
|
362
|
# excluded_subdirs (list): The optional list of subdirs to not fetch
|
350
|
363
|
#
|
351
|
364
|
def _yield_directory_digests(self, dir_digest, *, excluded_subdirs=[]):
|
|
365
|
+ # get directory blob
|
|
366
|
+ f = tempfile.NamedTemporaryFile(dir=self.tmpdir)
|
|
367
|
+ self._fetch_blob(dir_digest, f)
|
|
368
|
+ self.__tmp_downloads.append(f)
|
352
|
369
|
|
353
|
|
- objpath = self.request_blob(dir_digest)
|
354
|
|
-
|
|
370
|
+ # need to read in directory structure to iterate over it
|
355
|
371
|
directory = remote_execution_pb2.Directory()
|
356
|
|
-
|
357
|
|
- with open(objpath, 'rb') as f:
|
358
|
|
- directory.ParseFromString(f.read())
|
|
372
|
+ with open(f.name, 'rb') as tmp:
|
|
373
|
+ directory.ParseFromString(tmp.read())
|
359
|
374
|
|
360
|
375
|
yield dir_digest
|
361
|
376
|
for filenode in directory.files:
|
362
|
377
|
yield filenode.digest
|
363
|
|
-
|
364
|
378
|
for dirnode in directory.directories:
|
365
|
379
|
if dirnode.name not in excluded_subdirs:
|
366
|
380
|
yield from self._yield_directory_digests(dirnode.digest)
|
... |
... |
@@ -393,6 +407,15 @@ class CASRemote(): |
393
|
407
|
|
394
|
408
|
assert response.committed_size == digest.size_bytes
|
395
|
409
|
|
|
410
|
+ def _download_batch(self):
|
|
411
|
+ for _, data in self.__batch_read.send():
|
|
412
|
+ f = tempfile.NamedTemporaryFile(dir=self.tmpdir)
|
|
413
|
+ f.write(data)
|
|
414
|
+ f.flush()
|
|
415
|
+ self.__tmp_downloads.append(f)
|
|
416
|
+
|
|
417
|
+ self.__batch_read = _CASBatchRead(self)
|
|
418
|
+
|
396
|
419
|
|
397
|
420
|
# Represents a batch of blobs queued for fetching.
|
398
|
421
|
#
|