[Notes] [Git][BuildStream/buildstream][raoul/802-refactor-artifactcache] _casremote.py: Add batching to pull command



Title: GitLab

Raoul Hidalgo Charman pushed to branch raoul/802-refactor-artifactcache at BuildStream / buildstream

Commits:

1 changed file:

Changes:

  • buildstream/_cas/casremote.py
    ... ... @@ -95,6 +95,9 @@ class CASRemote():
    95 95
     
    
    96 96
             self.__tmp_downloads = []  # files in the tmpdir waiting to be added to local caches
    
    97 97
     
    
    98
    +        self.__batch_read = None
    
    99
    +        self.__batch_update = None
    
    100
    +
    
    98 101
         def init(self):
    
    99 102
             if not self._initialized:
    
    100 103
                 url = urlparse(self.spec.url)
    
    ... ... @@ -152,6 +155,7 @@ class CASRemote():
    152 155
                     request = remote_execution_pb2.BatchReadBlobsRequest()
    
    153 156
                     response = self.cas.BatchReadBlobs(request)
    
    154 157
                     self.batch_read_supported = True
    
    158
    +                self.__batch_read = _CASBatchRead(self)
    
    155 159
                 except grpc.RpcError as e:
    
    156 160
                     if e.code() != grpc.StatusCode.UNIMPLEMENTED:
    
    157 161
                         raise
    
    ... ... @@ -162,6 +166,7 @@ class CASRemote():
    162 166
                     request = remote_execution_pb2.BatchUpdateBlobsRequest()
    
    163 167
                     response = self.cas.BatchUpdateBlobs(request)
    
    164 168
                     self.batch_update_supported = True
    
    169
    +                self.__batch_update = _CASBatchUpdate(self)
    
    165 170
                 except grpc.RpcError as e:
    
    166 171
                     if (e.code() != grpc.StatusCode.UNIMPLEMENTED and
    
    167 172
                             e.code() != grpc.StatusCode.PERMISSION_DENIED):
    
    ... ... @@ -298,18 +303,22 @@ class CASRemote():
    298 303
     
    
    299 304
         # request_blob():
    
    300 305
         #
    
    301
    -    # Request blob and returns path to tmpdir location
    
    306
    +    # Request blob, triggering download via bytestream or cas
    
    307
    +    # BatchReadBlobs depending on size.
    
    302 308
         #
    
    303 309
         # Args:
    
    304 310
         #    digest (Digest): digest of the requested blob
    
    305
    -    #    path (str): tmpdir locations of downloaded blobs
    
    306 311
         #
    
    307 312
         def request_blob(self, digest):
    
    308
    -        # TODO expand for adding to batches some other logic
    
    309
    -        f = tempfile.NamedTemporaryFile(dir=self.tmpdir)
    
    310
    -        self._fetch_blob(digest, f)
    
    311
    -        self.__tmp_downloads.append(f)
    
    312
    -        return f.name
    
    313
    +        if (not self.batch_read_supported or
    
    314
    +                digest.size_bytes > self.max_batch_total_size_bytes):
    
    315
    +            f = tempfile.NamedTemporaryFile(dir=self.tmpdir)
    
    316
    +            self._fetch_blob(digest, f)
    
    317
    +            self.__tmp_downloads.append(f)
    
    318
    +            return
    
    319
    +        if self.__batch_read.add(digest) is False:
    
    320
    +            self._download_batch()
    
    321
    +            self.__batch_read.add(digest)
    
    313 322
     
    
    314 323
         # get_blobs():
    
    315 324
         #
    
    ... ... @@ -318,7 +327,12 @@ class CASRemote():
    318 327
         #
    
    319 328
         # Returns:
    
    320 329
         #    iterator over NamedTemporaryFile
    
    321
    -    def get_blobs(self):
    
    330
    +    def get_blobs(self, request_batch=False):
    
    331
    +        # Send read batch request and download
    
    332
    +        if (request_batch is True and
    
    333
    +                self.batch_read_supported is True):
    
    334
    +            self._download_batch()
    
    335
    +
    
    322 336
             while self.__tmp_downloads:
    
    323 337
                 yield self.__tmp_downloads.pop()
    
    324 338
     
    
    ... ... @@ -349,18 +363,19 @@ class CASRemote():
    349 363
         #     excluded_subdirs (list): The optional list of subdirs to not fetch
    
    350 364
         #
    
    351 365
         def _yield_directory_digests(self, dir_digest, *, excluded_subdirs=[]):
    
    366
    +        # get directory blob
    
    367
    +        f = tempfile.NamedTemporaryFile(dir=self.tmpdir)
    
    368
    +        self._fetch_blob(dir_digest, f)
    
    369
    +        self.__tmp_downloads.append(f)
    
    352 370
     
    
    353
    -        objpath = self.request_blob(dir_digest)
    
    354
    -
    
    371
    +        # need to read in directory structure to iterate over it
    
    355 372
             directory = remote_execution_pb2.Directory()
    
    356
    -
    
    357
    -        with open(objpath, 'rb') as f:
    
    358
    -            directory.ParseFromString(f.read())
    
    373
    +        with open(f.name, 'rb') as tmp:
    
    374
    +            directory.ParseFromString(tmp.read())
    
    359 375
     
    
    360 376
             yield dir_digest
    
    361 377
             for filenode in directory.files:
    
    362 378
                 yield filenode.digest
    
    363
    -
    
    364 379
             for dirnode in directory.directories:
    
    365 380
                 if dirnode.name not in excluded_subdirs:
    
    366 381
                     yield dirnode.digest
    
    ... ... @@ -394,6 +409,15 @@ class CASRemote():
    394 409
     
    
    395 410
             assert response.committed_size == digest.size_bytes
    
    396 411
     
    
    412
    +    def _download_batch(self):
    
    413
    +        for _, data in self.__batch_read.send():
    
    414
    +            f = tempfile.NamedTemporaryFile(dir=self.tmpdir)
    
    415
    +            f.write(data)
    
    416
    +            f.flush()
    
    417
    +            self.__tmp_downloads.append(f)
    
    418
    +
    
    419
    +        self.__batch_read = _CASBatchRead(self)
    
    420
    +
    
    397 421
     
    
    398 422
     # Represents a batch of blobs queued for fetching.
    
    399 423
     #
    



  • [Date Prev][Date Next]   [Thread Prev][Thread Next]   [Thread Index] [Date Index] [Author Index]