[Notes] [Git][BuildStream/buildstream][raoul/802-refactor-artifactcache] _casremote.py: Add batching to pull command



Title: GitLab

Raoul Hidalgo Charman pushed to branch raoul/802-refactor-artifactcache at BuildStream / buildstream

Commits:

2 changed files:

Changes:

  • buildstream/_artifactcache.py
    ... ... @@ -656,6 +656,11 @@ class ArtifactCache():
    656 656
                                 remote.request_blob(blob_digest)
    
    657 657
                                 for blob_file in remote.get_blobs():
    
    658 658
                                     self.cas.add_object(path=blob_file.name, link_directly=True)
    
    659
    +
    
    660
    +                        # request the final CAS batch
    
    661
    +                        for blob_file in remote.get_blobs(request_batch=True):
    
    662
    +                            self.cas.add_object(path=blob_file.name, link_directly=True)
    
    663
    +
    
    659 664
                             self.cas.set_ref(ref, root_digest)
    
    660 665
                         except BlobNotFound:
    
    661 666
                             element.info("Remote ({}) is missing blobs for {}".format(
    

  • buildstream/_cas/casremote.py
    ... ... @@ -95,6 +95,9 @@ class CASRemote():
    95 95
     
    
    96 96
             self.__tmp_downloads = []  # files in the tmpdir waiting to be added to local caches
    
    97 97
     
    
    98
    +        self.__batch_read = None
    
    99
    +        self.__batch_update = None
    
    100
    +
    
    98 101
         def init(self):
    
    99 102
             if not self._initialized:
    
    100 103
                 url = urlparse(self.spec.url)
    
    ... ... @@ -152,6 +155,7 @@ class CASRemote():
    152 155
                     request = remote_execution_pb2.BatchReadBlobsRequest()
    
    153 156
                     response = self.cas.BatchReadBlobs(request)
    
    154 157
                     self.batch_read_supported = True
    
    158
    +                self.__batch_read = _CASBatchRead(self)
    
    155 159
                 except grpc.RpcError as e:
    
    156 160
                     if e.code() != grpc.StatusCode.UNIMPLEMENTED:
    
    157 161
                         raise
    
    ... ... @@ -162,6 +166,7 @@ class CASRemote():
    162 166
                     request = remote_execution_pb2.BatchUpdateBlobsRequest()
    
    163 167
                     response = self.cas.BatchUpdateBlobs(request)
    
    164 168
                     self.batch_update_supported = True
    
    169
    +                self.__batch_update = _CASBatchUpdate(self)
    
    165 170
                 except grpc.RpcError as e:
    
    166 171
                     if (e.code() != grpc.StatusCode.UNIMPLEMENTED and
    
    167 172
                             e.code() != grpc.StatusCode.PERMISSION_DENIED):
    
    ... ... @@ -298,18 +303,21 @@ class CASRemote():
    298 303
     
    
    299 304
         # request_blob():
    
    300 305
         #
    
    301
    -    # Request blob and returns path to tmpdir location
    
    306
    +    # Request blob, triggering a download via ByteStream or CAS
    
    307
    +    # BatchReadBlobs depending on size.
    
    302 308
         #
    
    303 309
         # Args:
    
    304 310
         #    digest (Digest): digest of the requested blob
    
    305
    -    #    path (str): tmpdir locations of downloaded blobs
    
    306 311
         #
    
    307 312
         def request_blob(self, digest):
    
    308
    -        # TODO expand for adding to batches some other logic
    
    309
    -        f = tempfile.NamedTemporaryFile(dir=self.tmpdir)
    
    310
    -        self._fetch_blob(digest, f)
    
    311
    -        self.__tmp_downloads.append(f)
    
    312
    -        return f.name
    
    313
    +        if (not self.batch_read_supported or
    
    314
    +                digest.size_bytes > self.max_batch_total_size_bytes):
    
    315
    +            f = tempfile.NamedTemporaryFile(dir=self.tmpdir)
    
    316
    +            self._fetch_blob(digest, f)
    
    317
    +            self.__tmp_downloads.append(f)
    
    318
    +        elif self.__batch_read.add(digest) is False:
    
    319
    +            self._download_batch()
    
    320
    +            self.__batch_read.add(digest)
    
    313 321
     
    
    314 322
         # get_blobs():
    
    315 323
         #
    
    ... ... @@ -318,7 +326,12 @@ class CASRemote():
    318 326
         #
    
    319 327
         # Returns:
    
    320 328
         #    iterator over NamedTemporaryFile
    
    321
    -    def get_blobs(self):
    
    329
    +    def get_blobs(self, request_batch=False):
    
    330
    +        # Send read batch request and download
    
    331
    +        if (request_batch is True and
    
    332
    +                self.batch_read_supported is True):
    
    333
    +            self._download_batch()
    
    334
    +
    
    322 335
             while self.__tmp_downloads:
    
    323 336
                 yield self.__tmp_downloads.pop()
    
    324 337
     
    
    ... ... @@ -349,18 +362,19 @@ class CASRemote():
    349 362
         #     excluded_subdirs (list): The optional list of subdirs to not fetch
    
    350 363
         #
    
    351 364
         def _yield_directory_digests(self, dir_digest, *, excluded_subdirs=[]):
    
    365
    +        # get directory blob
    
    366
    +        f = tempfile.NamedTemporaryFile(dir=self.tmpdir)
    
    367
    +        self._fetch_blob(dir_digest, f)
    
    368
    +        self.__tmp_downloads.append(f)
    
    352 369
     
    
    353
    -        objpath = self.request_blob(dir_digest)
    
    354
    -
    
    370
    +        # need to read in directory structure to iterate over it
    
    355 371
             directory = remote_execution_pb2.Directory()
    
    356
    -
    
    357
    -        with open(objpath, 'rb') as f:
    
    358
    -            directory.ParseFromString(f.read())
    
    372
    +        with open(f.name, 'rb') as tmp:
    
    373
    +            directory.ParseFromString(tmp.read())
    
    359 374
     
    
    360 375
             yield dir_digest
    
    361 376
             for filenode in directory.files:
    
    362 377
                 yield filenode.digest
    
    363
    -
    
    364 378
             for dirnode in directory.directories:
    
    365 379
                 if dirnode.name not in excluded_subdirs:
    
    366 380
                     yield from self._yield_directory_digests(dirnode.digest)
    
    ... ... @@ -393,6 +407,15 @@ class CASRemote():
    393 407
     
    
    394 408
             assert response.committed_size == digest.size_bytes
    
    395 409
     
    
    410
    +    def _download_batch(self):
    
    411
    +        for _, data in self.__batch_read.send():
    
    412
    +            f = tempfile.NamedTemporaryFile(dir=self.tmpdir)
    
    413
    +            f.write(data)
    
    414
    +            f.flush()
    
    415
    +            self.__tmp_downloads.append(f)
    
    416
    +
    
    417
    +        self.__batch_read = _CASBatchRead(self)
    
    418
    +
    
    396 419
     
    
    397 420
     # Represents a batch of blobs queued for fetching.
    
    398 421
     #
    



  • [Date Prev][Date Next]   [Thread Prev][Thread Next]   [Thread Index] [Date Index] [Author Index]