[Notes] [Git][BuildStream/buildstream][raoul/802-refactor-artifactcache] 5 commits: casremote.py: Move remote CAS classes into its own file




Raoul Hidalgo Charman pushed to branch raoul/802-refactor-artifactcache at BuildStream / buildstream

Commits:

12 changed files:

Changes:

  • buildstream/_artifactcache.py
    @@ -19,18 +19,16 @@
     
     import multiprocessing
     import os
    -import signal
     import string
     from collections.abc import Mapping
     
     from .types import _KeyStrength
     from ._exceptions import ArtifactError, CASError, LoadError, LoadErrorReason
     from ._message import Message, MessageType
    -from . import _signals
     from . import utils
     from . import _yaml
     
    -from ._cas import CASRemote, CASRemoteSpec
    +from ._cas import BlobNotFound, CASRemote, CASRemoteSpec
     
     
     CACHE_SIZE_FILE = "cache_size"
    
    @@ -375,20 +373,8 @@ class ArtifactCache():
             remotes = {}
             q = multiprocessing.Queue()
             for remote_spec in remote_specs:
    -            # Use subprocess to avoid creation of gRPC threads in main BuildStream process
    -            # See https://github.com/grpc/grpc/blob/master/doc/fork_support.md for details
    -            p = multiprocessing.Process(target=self.cas.initialize_remote, args=(remote_spec, q))
     
    -            try:
    -                # Keep SIGINT blocked in the child process
    -                with _signals.blocked([signal.SIGINT], ignore=False):
    -                    p.start()
    -
    -                error = q.get()
    -                p.join()
    -            except KeyboardInterrupt:
    -                utils._kill_process_tree(p.pid)
    -                raise
    +            error = CASRemote.check_remote(remote_spec, self.context.tmpdir, q)
     
                 if error and on_failure:
                     on_failure(remote_spec.url, error)
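
    Note: the subprocess logic removed above presumably moves behind the new
    CASRemote.check_remote() entry point in casremote.py, which this notification
    does not show. A minimal sketch of what that classmethod could look like,
    reconstructed from the removed code; the exact signature and the
    _initialize_remote helper name are assumptions:

        import multiprocessing
        import signal

        from . import _signals, utils


        class CASRemote():

            @classmethod
            def check_remote(cls, remote_spec, tmpdir, q):
                # Use subprocess to avoid creation of gRPC threads in the main
                # BuildStream process.
                # See https://github.com/grpc/grpc/blob/master/doc/fork_support.md
                p = multiprocessing.Process(target=cls._initialize_remote,  # assumed helper
                                            args=(remote_spec, tmpdir, q))

                try:
                    # Keep SIGINT blocked in the child process
                    with _signals.blocked([signal.SIGINT], ignore=False):
                        p.start()

                    error = q.get()
                    p.join()
                except KeyboardInterrupt:
                    utils._kill_process_tree(p.pid)
                    raise

                return error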
    
    @@ -399,7 +385,7 @@ class ArtifactCache():
                     if remote_spec.push:
                         self._has_push_remotes = True
     
    -                remotes[remote_spec.url] = CASRemote(remote_spec)
    +                remotes[remote_spec.url] = CASRemote(remote_spec, self.context.tmpdir)
     
             for project in self.context.get_projects():
                 remote_specs = self.global_remote_specs
    
    @@ -621,16 +607,41 @@ class ArtifactCache():
     
             for remote in push_remotes:
                 remote.init()
    +            skipped_remote = True
                 display_key = element._get_brief_display_key()
                 element.status("Pushing artifact {} -> {}".format(display_key, remote.spec.url))
     
    -            if self.cas.push(refs, remote):
    -                element.info("Pushed artifact {} -> {}".format(display_key, remote.spec.url))
    +            try:
    +                for ref in refs:
    +                    # Check whether ref is already on the server in which case
    +                    # there is no need to push the ref
    +                    root_digest = self.cas.resolve_ref(ref)
    +                    response = remote.get_reference(ref)
    +                    if (response is not None and
    +                            response.hash == root_digest.hash and
    +                            response.size_bytes == root_digest.size_bytes):
    +                        element.info("Remote ({}) already has {} cached".format(
    +                            remote.spec.url, element._get_brief_display_key()))
    +                        continue
    +
    +                    # upload blobs
    +                    self._send_directory(root_digest, remote)
    +                    remote.update_reference(ref, root_digest)
    +
    +                    skipped_remote = False
    +
    +            except CASError as e:
    +                if str(e.reason) == "StatusCode.RESOURCE_EXHAUSTED":
    +                    element.warn("Failed to push element to {}: Resource exhausted"
    +                                 .format(remote.spec.url))
    +                    continue
    +                else:
    +                    raise ArtifactError("Failed to push refs {}: {}".format(refs, e),
    +                                        temporary=True) from e
    +
    +            if skipped_remote is False:
                     pushed = True
    -            else:
    -                element.info("Remote ({}) already has {} cached".format(
    -                    remote.spec.url, element._get_brief_display_key()
    -                ))
    +                element.info("Pushed artifact {} -> {}".format(display_key, remote.spec.url))
     
             return pushed
     
    @@ -658,19 +669,31 @@ class ArtifactCache():
                     display_key = element._get_brief_display_key()
                     element.status("Pulling artifact {} <- {}".format(display_key, remote.spec.url))
     
    -                if self.cas.pull(ref, remote, progress=progress, subdir=subdir, excluded_subdirs=excluded_subdirs):
    -                    element.info("Pulled artifact {} <- {}".format(display_key, remote.spec.url))
    -                    if subdir:
    -                        # Attempt to extract subdir into artifact extract dir if it already exists
    -                        # without containing the subdir. If the respective artifact extract dir does not
    -                        # exist a complete extraction will complete.
    -                        self.extract(element, key, subdir)
    -                    # no need to pull from additional remotes
    -                    return True
    -                else:
    +                root_digest = remote.get_reference(ref)
    +
    +                if not root_digest:
                         element.info("Remote ({}) does not have {} cached".format(
    -                        remote.spec.url, element._get_brief_display_key()
    -                    ))
    +                        remote.spec.url, element._get_brief_display_key()))
    +                    continue
    +
    +                try:
    +                    self._fetch_directory(remote, root_digest, excluded_subdirs)
    +                except BlobNotFound:
    +                    element.info("Remote ({}) is missing blobs for {}".format(
    +                        remote.spec.url, element._get_brief_display_key()))
    +                    continue
    +
    +                self.cas.set_ref(ref, root_digest)
    +
    +                if subdir:
    +                    # Attempt to extract subdir into artifact extract dir if it already exists
    +                    # without containing the subdir. If the respective artifact extract dir does not
    +                    # exist a complete extraction will complete.
    +                    self.extract(element, key, subdir)
    +
    +                element.info("Pulled artifact {} <- {}".format(display_key, remote.spec.url))
    +                # no need to pull from additional remotes
    +                return True
     
                 except CASError as e:
                     raise ArtifactError("Failed to pull artifact {}: {}".format(
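
    Both this pull rewrite and the push rewrite above call remote.get_reference()
    and remote.update_reference() in place of the raw ReferenceStorage RPCs that
    the old CASCache code issued inline (see the cascache.py removals below). A
    plausible sketch of these CASRemote methods, assuming they keep the old
    request shapes; the bodies are reconstructions, not the committed code:

        # Hypothetical CASRemote methods; casremote.py is not shown in this
        # notification, so the error handling here is an assumption.
        import grpc

        from .._protos.buildstream.v2 import buildstream_pb2


        class CASRemote():

            def get_reference(self, ref):
                # Return the root digest for a ref, or None if the server does
                # not have it (mirrors the old NOT_FOUND handling in pull()).
                try:
                    request = buildstream_pb2.GetReferenceRequest(instance_name=self.spec.instance_name)
                    request.key = ref
                    return self.ref_storage.GetReference(request).digest
                except grpc.RpcError as e:
                    if e.code() != grpc.StatusCode.NOT_FOUND:
                        raise
                    return None

            def update_reference(self, ref, root_digest):
                # Point a ref at a root digest on the server, as the old
                # CASCache.push() did inline with UpdateReferenceRequest.
                request = buildstream_pb2.UpdateReferenceRequest(instance_name=self.spec.instance_name)
                request.keys.append(ref)
                request.digest.hash = root_digest.hash
                request.digest.size_bytes = root_digest.size_bytes
                self.ref_storage.UpdateReference(request)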
    
    @@ -685,15 +708,26 @@ class ArtifactCache():
         #
         # Args:
         #     project (Project): The current project
    -    #     digest (Digest): The digest of the tree
    +    #     tree_digest (Digest): The digest of the tree
         #
    -    def pull_tree(self, project, digest):
    +    def pull_tree(self, project, tree_digest):
             for remote in self._remotes[project]:
    -            digest = self.cas.pull_tree(remote, digest)
    -
    -            if digest:
    -                # no need to pull from additional remotes
    -                return digest
    +            try:
    +                for blob_digest in remote.yield_tree_digests(tree_digest):
    +                    if self.cas.check_blob(blob_digest):
    +                        continue
    +                    remote.request_blob(blob_digest)
    +                    for blob_file in remote.get_blobs():
    +                        self.cas.add_object(path=blob_file.name, link_directly=True)
    +
    +                # Get the last batch
    +                for blob_file in remote.get_blobs(complete_batch=True):
    +                    self.cas.add_object(path=blob_file.name, link_directly=True)
    +
    +            except BlobNotFound:
    +                continue
    +            else:
    +                return tree_digest
     
             return None
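
    pull_tree() now drives the CASRemote batching API: request_blob() queues a
    digest and get_blobs() yields temporary files for downloads that have
    completed, with complete_batch=True flushing the final partial batch. The
    same drain pattern reappears in _fetch_directory() below; distilled here
    into a free-standing helper for illustration (hypothetical, not part of
    the commit):

        def _fetch_blobs(cas, remote, digests):
            for digest in digests:
                if cas.check_blob(digest):
                    continue                     # already in the local CAS
                remote.request_blob(digest)      # queue into the current batch
                for blob_file in remote.get_blobs():
                    # harvest whatever batches have completed so far
                    cas.add_object(path=blob_file.name, link_directly=True)

            # force the final partial batch and drain the remainder
            for blob_file in remote.get_blobs(complete_batch=True):
                cas.add_object(path=blob_file.name, link_directly=True)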
    
    @@ -722,7 +756,7 @@ class ArtifactCache():
                 return
     
             for remote in push_remotes:
    -            self.cas.push_directory(remote, directory)
    +            self._send_directory(directory.ref, remote)
     
         # push_message():
         #
    
    @@ -747,7 +781,7 @@ class ArtifactCache():
                                     "servers are configured as push remotes.")
     
             for remote in push_remotes:
    -            message_digest = self.cas.push_message(remote, message)
    +            message_digest = remote.push_message(message)
     
             return message_digest
     
    
    @@ -807,6 +841,14 @@ class ArtifactCache():
             with self.context.timed_activity("Initializing remote caches", silent_nested=True):
                 self.initialize_remotes(on_failure=remote_failed)
     
    +    def _send_directory(self, root_digest, remote):
    +        required_blobs = self.cas.yield_directory_digests(root_digest)
    +        missing_blobs = remote.find_missing_blobs(required_blobs)
    +        for blob in missing_blobs.values():
    +            blob_file = self.cas.objpath(blob)
    +            remote.upload_blob(blob, blob_file)
    +        remote.send_update_batch()
    +
         # _write_cache_size()
         #
         # Writes the given size of the artifact to the cache's size file
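
    The new _send_directory() delegates deduplication to
    remote.find_missing_blobs(), which presumably batches FindMissingBlobs
    requests the way the removed CASCache._send_directory() did (see the
    cascache.py hunks below). A sketch of that assumed wrapper, adapted from
    the removed code; _grouper is the chunking helper the old code used:

        # Hypothetical CASRemote.find_missing_blobs(); casremote.py is not
        # shown in this notification.
        from .._protos.build.bazel.remote.execution.v2 import remote_execution_pb2

        def find_missing_blobs(self, required_blobs):
            # Return {hash: Digest} for blobs the server lacks, capping each
            # FindMissingBlobs request at 512 digests as the old code did.
            missing_blobs = {}
            for group in _grouper(required_blobs, 512):  # chunking helper, as before
                request = remote_execution_pb2.FindMissingBlobsRequest(instance_name=self.spec.instance_name)
                for required_digest in group:
                    d = request.blob_digests.add()
                    d.hash = required_digest.hash
                    d.size_bytes = required_digest.size_bytes

                response = self.cas.FindMissingBlobs(request)
                for missing_digest in response.missing_blob_digests:
                    d = remote_execution_pb2.Digest()
                    d.hash = missing_digest.hash
                    d.size_bytes = missing_digest.size_bytes
                    missing_blobs[d.hash] = d

            return missing_blobs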
    
    @@ -931,6 +973,19 @@ class ArtifactCache():
             stat = os.statvfs(volume)
             return stat.f_bsize * stat.f_bavail, stat.f_bsize * stat.f_blocks
     
    +    def _fetch_directory(self, remote, root_digest, excluded_subdirs):
    +        for blob_digest in remote.yield_directory_digests(
    +                root_digest, excluded_subdirs=excluded_subdirs):
    +            if self.cas.check_blob(blob_digest):
    +                continue
    +            remote.request_blob(blob_digest)
    +            for blob_file in remote.get_blobs():
    +                self.cas.add_object(path=blob_file.name, link_directly=True)
    +
    +        # Request final CAS batch
    +        for blob_file in remote.get_blobs(complete_batch=True):
    +            self.cas.add_object(path=blob_file.name, link_directly=True)
    +
     
     # _configured_remote_artifact_cache_specs():
     #

  • buildstream/_cas/__init__.py
    @@ -17,4 +17,5 @@
     #  Authors:
     #        Tristan Van Berkom <tristan vanberkom codethink co uk>
     
    -from .cascache import CASCache, CASRemote, CASRemoteSpec
    +from .cascache import CASCache
    +from .casremote import CASRemote, CASRemoteSpec, BlobNotFound
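
    The package's public surface is preserved by the split: callers keep
    importing CASRemote and CASRemoteSpec from buildstream._cas, with
    BlobNotFound newly re-exported alongside them. For example (illustrative
    only):

        # Call sites are insulated from the cascache.py/casremote.py split:
        from buildstream._cas import BlobNotFound, CASCache, CASRemote, CASRemoteSpec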

  • buildstream/_cas/cascache.py
    @@ -17,85 +17,16 @@
     #  Authors:
     #        Jürg Billeter <juerg billeter codethink co uk>
     
    -from collections import namedtuple
     import hashlib
    -import itertools
    -import io
     import os
     import stat
     import tempfile
    -import uuid
     import contextlib
    -from urllib.parse import urlparse
     
    -import grpc
    -
    -from .._protos.google.rpc import code_pb2
    -from .._protos.google.bytestream import bytestream_pb2, bytestream_pb2_grpc
    -from .._protos.build.bazel.remote.execution.v2 import remote_execution_pb2, remote_execution_pb2_grpc
    -from .._protos.buildstream.v2 import buildstream_pb2, buildstream_pb2_grpc
    +from .._protos.build.bazel.remote.execution.v2 import remote_execution_pb2
     
     from .. import utils
    -from .._exceptions import CASError, LoadError, LoadErrorReason
    -from .. import _yaml
    -
    -
    -# The default limit for gRPC messages is 4 MiB.
    -# Limit payload to 1 MiB to leave sufficient headroom for metadata.
    -_MAX_PAYLOAD_BYTES = 1024 * 1024
    -
    -
    -class CASRemoteSpec(namedtuple('CASRemoteSpec', 'url push server_cert client_key client_cert instance_name')):
    -
    -    # _new_from_config_node
    -    #
    -    # Creates an CASRemoteSpec() from a YAML loaded node
    -    #
    -    @staticmethod
    -    def _new_from_config_node(spec_node, basedir=None):
    -        _yaml.node_validate(spec_node, ['url', 'push', 'server-cert', 'client-key', 'client-cert', 'instance_name'])
    -        url = _yaml.node_get(spec_node, str, 'url')
    -        push = _yaml.node_get(spec_node, bool, 'push', default_value=False)
    -        if not url:
    -            provenance = _yaml.node_get_provenance(spec_node, 'url')
    -            raise LoadError(LoadErrorReason.INVALID_DATA,
    -                            "{}: empty artifact cache URL".format(provenance))
    -
    -        instance_name = _yaml.node_get(spec_node, str, 'instance_name', default_value=None)
    -
    -        server_cert = _yaml.node_get(spec_node, str, 'server-cert', default_value=None)
    -        if server_cert and basedir:
    -            server_cert = os.path.join(basedir, server_cert)
    -
    -        client_key = _yaml.node_get(spec_node, str, 'client-key', default_value=None)
    -        if client_key and basedir:
    -            client_key = os.path.join(basedir, client_key)
    -
    -        client_cert = _yaml.node_get(spec_node, str, 'client-cert', default_value=None)
    -        if client_cert and basedir:
    -            client_cert = os.path.join(basedir, client_cert)
    -
    -        if client_key and not client_cert:
    -            provenance = _yaml.node_get_provenance(spec_node, 'client-key')
    -            raise LoadError(LoadErrorReason.INVALID_DATA,
    -                            "{}: 'client-key' was specified without 'client-cert'".format(provenance))
    -
    -        if client_cert and not client_key:
    -            provenance = _yaml.node_get_provenance(spec_node, 'client-cert')
    -            raise LoadError(LoadErrorReason.INVALID_DATA,
    -                            "{}: 'client-cert' was specified without 'client-key'".format(provenance))
    -
    -        return CASRemoteSpec(url, push, server_cert, client_key, client_cert, instance_name)
    -
    -
    -CASRemoteSpec.__new__.__defaults__ = (None, None, None, None)
    -
    -
    -class BlobNotFound(CASError):
    -
    -    def __init__(self, blob, msg):
    -        self.blob = blob
    -        super().__init__(msg)
    +from .._exceptions import CASCacheError
     
     
     # A CASCache manages a CAS repository as specified in the Remote Execution API.
    
    @@ -120,7 +51,7 @@ class CASCache():
             headdir = os.path.join(self.casdir, 'refs', 'heads')
             objdir = os.path.join(self.casdir, 'objects')
             if not (os.path.isdir(headdir) and os.path.isdir(objdir)):
    -            raise CASError("CAS repository check failed for '{}'".format(self.casdir))
    +            raise CASCacheError("CAS repository check failed for '{}'".format(self.casdir))
     
         # contains():
         #
    @@ -169,7 +100,7 @@ class CASCache():
         #     subdir (str): Optional specific dir to extract
         #
         # Raises:
    -    #     CASError: In cases there was an OSError, or if the ref did not exist.
    +    #     CASCacheError: In cases there was an OSError, or if the ref did not exist.
         #
         # Returns: path to extracted directory
         #
    
    @@ -201,7 +132,7 @@ class CASCache():
                     # Another process beat us to rename
                     pass
                 except OSError as e:
    -                raise CASError("Failed to extract directory for ref '{}': {}".format(ref, e)) from e
    +                raise CASCacheError("Failed to extract directory for ref '{}': {}".format(ref, e)) from e
     
             return originaldest
     
    @@ -245,96 +176,6 @@ class CASCache():
     
             return modified, removed, added
     
    -    def initialize_remote(self, remote_spec, q):
    -        try:
    -            remote = CASRemote(remote_spec)
    -            remote.init()
    -
    -            request = buildstream_pb2.StatusRequest(instance_name=remote_spec.instance_name)
    -            response = remote.ref_storage.Status(request)
    -
    -            if remote_spec.push and not response.allow_updates:
    -                q.put('CAS server does not allow push')
    -            else:
    -                # No error
    -                q.put(None)
    -
    -        except grpc.RpcError as e:
    -            # str(e) is too verbose for errors reported to the user
    -            q.put(e.details())
    -
    -        except Exception as e:               # pylint: disable=broad-except
    -            # Whatever happens, we need to return it to the calling process
    -            #
    -            q.put(str(e))
    -
    -    # pull():
    -    #
    -    # Pull a ref from a remote repository.
    -    #
    -    # Args:
    -    #     ref (str): The ref to pull
    -    #     remote (CASRemote): The remote repository to pull from
    -    #     progress (callable): The progress callback, if any
    -    #     subdir (str): The optional specific subdir to pull
    -    #     excluded_subdirs (list): The optional list of subdirs to not pull
    -    #
    -    # Returns:
    -    #   (bool): True if pull was successful, False if ref was not available
    -    #
    -    def pull(self, ref, remote, *, progress=None, subdir=None, excluded_subdirs=None):
    -        try:
    -            remote.init()
    -
    -            request = buildstream_pb2.GetReferenceRequest(instance_name=remote.spec.instance_name)
    -            request.key = ref
    -            response = remote.ref_storage.GetReference(request)
    -
    -            tree = remote_execution_pb2.Digest()
    -            tree.hash = response.digest.hash
    -            tree.size_bytes = response.digest.size_bytes
    -
    -            # Check if the element artifact is present, if so just fetch the subdir.
    -            if subdir and os.path.exists(self.objpath(tree)):
    -                self._fetch_subdir(remote, tree, subdir)
    -            else:
    -                # Fetch artifact, excluded_subdirs determined in pullqueue
    -                self._fetch_directory(remote, tree, excluded_subdirs=excluded_subdirs)
    -
    -            self.set_ref(ref, tree)
    -
    -            return True
    -        except grpc.RpcError as e:
    -            if e.code() != grpc.StatusCode.NOT_FOUND:
    -                raise CASError("Failed to pull ref {}: {}".format(ref, e)) from e
    -            else:
    -                return False
    -        except BlobNotFound as e:
    -            return False
    -
    -    # pull_tree():
    -    #
    -    # Pull a single Tree rather than a ref.
    -    # Does not update local refs.
    -    #
    -    # Args:
    -    #     remote (CASRemote): The remote to pull from
    -    #     digest (Digest): The digest of the tree
    -    #
    -    def pull_tree(self, remote, digest):
    -        try:
    -            remote.init()
    -
    -            digest = self._fetch_tree(remote, digest)
    -
    -            return digest
    -
    -        except grpc.RpcError as e:
    -            if e.code() != grpc.StatusCode.NOT_FOUND:
    -                raise
    -
    -        return None
    -
         # link_ref():
         #
         # Add an alias for an existing ref.
    
    @@ -348,117 +189,6 @@ class CASCache():
     
             self.set_ref(newref, tree)
     
    -    # push():
    -    #
    -    # Push committed refs to remote repository.
    -    #
    -    # Args:
    -    #     refs (list): The refs to push
    -    #     remote (CASRemote): The remote to push to
    -    #
    -    # Returns:
    -    #   (bool): True if any remote was updated, False if no pushes were required
    -    #
    -    # Raises:
    -    #   (CASError): if there was an error
    -    #
    -    def push(self, refs, remote):
    -        skipped_remote = True
    -        try:
    -            for ref in refs:
    -                tree = self.resolve_ref(ref)
    -
    -                # Check whether ref is already on the server in which case
    -                # there is no need to push the ref
    -                try:
    -                    request = buildstream_pb2.GetReferenceRequest(instance_name=remote.spec.instance_name)
    -                    request.key = ref
    -                    response = remote.ref_storage.GetReference(request)
    -
    -                    if response.digest.hash == tree.hash and response.digest.size_bytes == tree.size_bytes:
    -                        # ref is already on the server with the same tree
    -                        continue
    -
    -                except grpc.RpcError as e:
    -                    if e.code() != grpc.StatusCode.NOT_FOUND:
    -                        # Intentionally re-raise RpcError for outer except block.
    -                        raise
    -
    -                self._send_directory(remote, tree)
    -
    -                request = buildstream_pb2.UpdateReferenceRequest(instance_name=remote.spec.instance_name)
    -                request.keys.append(ref)
    -                request.digest.hash = tree.hash
    -                request.digest.size_bytes = tree.size_bytes
    -                remote.ref_storage.UpdateReference(request)
    -
    -                skipped_remote = False
    -        except grpc.RpcError as e:
    -            if e.code() != grpc.StatusCode.RESOURCE_EXHAUSTED:
    -                raise CASError("Failed to push ref {}: {}".format(refs, e), temporary=True) from e
    -
    -        return not skipped_remote
    -
    -    # push_directory():
    -    #
    -    # Push the given virtual directory to a remote.
    -    #
    -    # Args:
    -    #     remote (CASRemote): The remote to push to
    -    #     directory (Directory): A virtual directory object to push.
    -    #
    -    # Raises:
    -    #     (CASError): if there was an error
    -    #
    -    def push_directory(self, remote, directory):
    -        remote.init()
    -
    -        self._send_directory(remote, directory.ref)
    -
    -    # push_message():
    -    #
    -    # Push the given protobuf message to a remote.
    -    #
    -    # Args:
    -    #     remote (CASRemote): The remote to push to
    -    #     message (Message): A protobuf message to push.
    -    #
    -    # Raises:
    -    #     (CASError): if there was an error
    -    #
    -    def push_message(self, remote, message):
    -
    -        message_buffer = message.SerializeToString()
    -        message_digest = utils._message_digest(message_buffer)
    -
    -        remote.init()
    -
    -        with io.BytesIO(message_buffer) as b:
    -            self._send_blob(remote, message_digest, b)
    -
    -        return message_digest
    -
    -    # verify_digest_on_remote():
    -    #
    -    # Check whether the object is already on the server in which case
    -    # there is no need to upload it.
    -    #
    -    # Args:
    -    #     remote (CASRemote): The remote to check
    -    #     digest (Digest): The object digest.
    -    #
    -    def verify_digest_on_remote(self, remote, digest):
    -        remote.init()
    -
    -        request = remote_execution_pb2.FindMissingBlobsRequest(instance_name=remote.spec.instance_name)
    -        request.blob_digests.extend([digest])
    -
    -        response = remote.cas.FindMissingBlobs(request)
    -        if digest in response.missing_blob_digests:
    -            return False
    -
    -        return True
    -
         # objpath():
         #
         # Return the path of an object based on its digest.
    
    @@ -531,7 +261,7 @@ class CASCache():
                 pass
     
             except OSError as e:
    -            raise CASError("Failed to hash object: {}".format(e)) from e
    +            raise CASCacheError("Failed to hash object: {}".format(e)) from e
     
             return digest
     
    
    @@ -572,7 +302,7 @@ class CASCache():
                     return digest
     
             except FileNotFoundError as e:
    -            raise CASError("Attempt to access unavailable ref: {}".format(e)) from e
    +            raise CASCacheError("Attempt to access unavailable ref: {}".format(e)) from e
     
         # update_mtime()
         #
    
    @@ -585,7 +315,7 @@ class CASCache():
             try:
                 os.utime(self._refpath(ref))
             except FileNotFoundError as e:
    -            raise CASError("Attempt to access unavailable ref: {}".format(e)) from e
    +            raise CASCacheError("Attempt to access unavailable ref: {}".format(e)) from e
     
         # calculate_cache_size()
         #
    
    @@ -676,7 +406,7 @@ class CASCache():
             # Remove cache ref
             refpath = self._refpath(ref)
             if not os.path.exists(refpath):
    -            raise CASError("Could not find ref '{}'".format(ref))
    +            raise CASCacheError("Could not find ref '{}'".format(ref))
     
             os.unlink(refpath)
     
    
    @@ -720,6 +450,37 @@ class CASCache():
             reachable = set()
             self._reachable_refs_dir(reachable, tree, update_mtime=True)
     
    +    # Check to see if a blob is in the local CAS
    +    # return None if not
    +    def check_blob(self, digest):
    +        objpath = self.objpath(digest)
    +        if os.path.exists(objpath):
    +            # already in local repository
    +            return objpath
    +        else:
    +            return None
    +
    +    def yield_directory_digests(self, directory_digest):
    +        # parse directory, and recursively add blobs
    +        d = remote_execution_pb2.Digest()
    +        d.hash = directory_digest.hash
    +        d.size_bytes = directory_digest.size_bytes
    +        yield d
    +
    +        directory = remote_execution_pb2.Directory()
    +
    +        with open(self.objpath(directory_digest), 'rb') as f:
    +            directory.ParseFromString(f.read())
    +
    +        for filenode in directory.files:
    +            d = remote_execution_pb2.Digest()
    +            d.hash = filenode.digest.hash
    +            d.size_bytes = filenode.digest.size_bytes
    +            yield d
    +
    +        for dirnode in directory.directories:
    +            yield from self.yield_directory_digests(dirnode.digest)
    +
         ################################################
         #             Local Private Methods            #
         ################################################
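
    Taken together, yield_directory_digests() walks every digest reachable from
    a root Directory and check_blob() filters out what is already local; this
    is exactly how the new ArtifactCache._send_directory() and
    _fetch_directory() consume them. For instance, to list the digests a local
    CAS is missing (illustrative snippet, not part of the commit):

        # 'cas' is a CASCache instance; 'root_digest' a Digest proto.
        missing_locally = [d for d in cas.yield_directory_digests(root_digest)
                           if not cas.check_blob(d)]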
    
    @@ -792,7 +553,7 @@ class CASCache():
                     # The process serving the socket can't be cached anyway
                     pass
                 else:
    -                raise CASError("Unsupported file type for {}".format(full_path))
    +                raise CASCacheError("Unsupported file type for {}".format(full_path))
     
             return self.add_object(digest=dir_digest,
                                    buffer=directory.SerializeToString())
    
    @@ -811,7 +572,7 @@ class CASCache():
                 if dirnode.name == name:
                     return dirnode.digest
     
    -        raise CASError("Subdirectory {} not found".format(name))
    +        raise CASCacheError("Subdirectory {} not found".format(name))
     
         def _diff_trees(self, tree_a, tree_b, *, added, removed, modified, path=""):
             dir_a = remote_execution_pb2.Directory()
    
    @@ -908,429 +669,3 @@ class CASCache():
     
             for dirnode in directory.directories:
                 yield from self._required_blobs(dirnode.digest)
    -
    -    def _fetch_blob(self, remote, digest, stream):
    -        resource_name_components = ['blobs', digest.hash, str(digest.size_bytes)]
    -
    -        if remote.spec.instance_name:
    -            resource_name_components.insert(0, remote.spec.instance_name)
    -
    -        resource_name = '/'.join(resource_name_components)
    -
    -        request = bytestream_pb2.ReadRequest()
    -        request.resource_name = resource_name
    -        request.read_offset = 0
    -        for response in remote.bytestream.Read(request):
    -            stream.write(response.data)
    -        stream.flush()
    -
    -        assert digest.size_bytes == os.fstat(stream.fileno()).st_size
    -
    -    # _ensure_blob():
    -    #
    -    # Fetch and add blob if it's not already local.
    -    #
    -    # Args:
    -    #     remote (Remote): The remote to use.
    -    #     digest (Digest): Digest object for the blob to fetch.
    -    #
    -    # Returns:
    -    #     (str): The path of the object
    -    #
    -    def _ensure_blob(self, remote, digest):
    -        objpath = self.objpath(digest)
    -        if os.path.exists(objpath):
    -            # already in local repository
    -            return objpath
    -
    -        with tempfile.NamedTemporaryFile(dir=self.tmpdir) as f:
    -            self._fetch_blob(remote, digest, f)
    -
    -            added_digest = self.add_object(path=f.name, link_directly=True)
    -            assert added_digest.hash == digest.hash
    -
    -        return objpath
    -
    -    def _batch_download_complete(self, batch):
    -        for digest, data in batch.send():
    -            with tempfile.NamedTemporaryFile(dir=self.tmpdir) as f:
    -                f.write(data)
    -                f.flush()
    -
    -                added_digest = self.add_object(path=f.name, link_directly=True)
    -                assert added_digest.hash == digest.hash
    -
    -    # Helper function for _fetch_directory().
    -    def _fetch_directory_batch(self, remote, batch, fetch_queue, fetch_next_queue):
    -        self._batch_download_complete(batch)
    -
    -        # All previously scheduled directories are now locally available,
    -        # move them to the processing queue.
    -        fetch_queue.extend(fetch_next_queue)
    -        fetch_next_queue.clear()
    -        return _CASBatchRead(remote)
    -
    -    # Helper function for _fetch_directory().
    -    def _fetch_directory_node(self, remote, digest, batch, fetch_queue, fetch_next_queue, *, recursive=False):
    -        in_local_cache = os.path.exists(self.objpath(digest))
    -
    -        if in_local_cache:
    -            # Skip download, already in local cache.
    -            pass
    -        elif (digest.size_bytes >= remote.max_batch_total_size_bytes or
    -              not remote.batch_read_supported):
    -            # Too large for batch request, download in independent request.
    -            self._ensure_blob(remote, digest)
    -            in_local_cache = True
    -        else:
    -            if not batch.add(digest):
    -                # Not enough space left in batch request.
    -                # Complete pending batch first.
    -                batch = self._fetch_directory_batch(remote, batch, fetch_queue, fetch_next_queue)
    -                batch.add(digest)
    -
    -        if recursive:
    -            if in_local_cache:
    -                # Add directory to processing queue.
    -                fetch_queue.append(digest)
    -            else:
    -                # Directory will be available after completing pending batch.
    -                # Add directory to deferred processing queue.
    -                fetch_next_queue.append(digest)
    -
    -        return batch
    -
    -    # _fetch_directory():
    -    #
    -    # Fetches remote directory and adds it to content addressable store.
    -    #
    -    # Fetches files, symbolic links and recursively other directories in
    -    # the remote directory and adds them to the content addressable
    -    # store.
    -    #
    -    # Args:
    -    #     remote (Remote): The remote to use.
    -    #     dir_digest (Digest): Digest object for the directory to fetch.
    -    #     excluded_subdirs (list): The optional list of subdirs to not fetch
    -    #
    -    def _fetch_directory(self, remote, dir_digest, *, excluded_subdirs=None):
    -        fetch_queue = [dir_digest]
    -        fetch_next_queue = []
    -        batch = _CASBatchRead(remote)
    -        if not excluded_subdirs:
    -            excluded_subdirs = []
    -
    -        while len(fetch_queue) + len(fetch_next_queue) > 0:
    -            if not fetch_queue:
    -                batch = self._fetch_directory_batch(remote, batch, fetch_queue, fetch_next_queue)
    -
    -            dir_digest = fetch_queue.pop(0)
    -
    -            objpath = self._ensure_blob(remote, dir_digest)
    -
    -            directory = remote_execution_pb2.Directory()
    -            with open(objpath, 'rb') as f:
    -                directory.ParseFromString(f.read())
    -
    -            for dirnode in directory.directories:
    -                if dirnode.name not in excluded_subdirs:
    -                    batch = self._fetch_directory_node(remote, dirnode.digest, batch,
    -                                                       fetch_queue, fetch_next_queue, recursive=True)
    -
    -            for filenode in directory.files:
    -                batch = self._fetch_directory_node(remote, filenode.digest, batch,
    -                                                   fetch_queue, fetch_next_queue)
    -
    -        # Fetch final batch
    -        self._fetch_directory_batch(remote, batch, fetch_queue, fetch_next_queue)
    -
    -    def _fetch_subdir(self, remote, tree, subdir):
    -        subdirdigest = self._get_subdir(tree, subdir)
    -        self._fetch_directory(remote, subdirdigest)
    -
    -    def _fetch_tree(self, remote, digest):
    -        # download but do not store the Tree object
    -        with tempfile.NamedTemporaryFile(dir=self.tmpdir) as out:
    -            self._fetch_blob(remote, digest, out)
    -
    -            tree = remote_execution_pb2.Tree()
    -
    -            with open(out.name, 'rb') as f:
    -                tree.ParseFromString(f.read())
    -
    -            tree.children.extend([tree.root])
    -            for directory in tree.children:
    -                for filenode in directory.files:
    -                    self._ensure_blob(remote, filenode.digest)
    -
    -                # place directory blob only in final location when we've downloaded
    -                # all referenced blobs to avoid dangling references in the repository
    -                dirbuffer = directory.SerializeToString()
    -                dirdigest = self.add_object(buffer=dirbuffer)
    -                assert dirdigest.size_bytes == len(dirbuffer)
    -
    -        return dirdigest
    -
    -    def _send_blob(self, remote, digest, stream, u_uid=uuid.uuid4()):
    -        resource_name_components = ['uploads', str(u_uid), 'blobs',
    -                                    digest.hash, str(digest.size_bytes)]
    -
    -        if remote.spec.instance_name:
    -            resource_name_components.insert(0, remote.spec.instance_name)
    -
    -        resource_name = '/'.join(resource_name_components)
    -
    -        def request_stream(resname, instream):
    -            offset = 0
    -            finished = False
    -            remaining = digest.size_bytes
    -            while not finished:
    -                chunk_size = min(remaining, _MAX_PAYLOAD_BYTES)
    -                remaining -= chunk_size
    -
    -                request = bytestream_pb2.WriteRequest()
    -                request.write_offset = offset
    -                # max. _MAX_PAYLOAD_BYTES chunks
    -                request.data = instream.read(chunk_size)
    -                request.resource_name = resname
    -                request.finish_write = remaining <= 0
    -
    -                yield request
    -
    -                offset += chunk_size
    -                finished = request.finish_write
    -
    -        response = remote.bytestream.Write(request_stream(resource_name, stream))
    -
    -        assert response.committed_size == digest.size_bytes
    -
    -    def _send_directory(self, remote, digest, u_uid=uuid.uuid4()):
    -        required_blobs = self._required_blobs(digest)
    -
    -        missing_blobs = dict()
    -        # Limit size of FindMissingBlobs request
    -        for required_blobs_group in _grouper(required_blobs, 512):
    -            request = remote_execution_pb2.FindMissingBlobsRequest(instance_name=remote.spec.instance_name)
    -
    -            for required_digest in required_blobs_group:
    -                d = request.blob_digests.add()
    -                d.hash = required_digest.hash
    -                d.size_bytes = required_digest.size_bytes
    -
    -            response = remote.cas.FindMissingBlobs(request)
    -            for missing_digest in response.missing_blob_digests:
    -                d = remote_execution_pb2.Digest()
    -                d.hash = missing_digest.hash
    -                d.size_bytes = missing_digest.size_bytes
    -                missing_blobs[d.hash] = d
    -
    -        # Upload any blobs missing on the server
    -        self._send_blobs(remote, missing_blobs.values(), u_uid)
    -
    -    def _send_blobs(self, remote, digests, u_uid=uuid.uuid4()):
    -        batch = _CASBatchUpdate(remote)
    -
    -        for digest in digests:
    -            with open(self.objpath(digest), 'rb') as f:
    -                assert os.fstat(f.fileno()).st_size == digest.size_bytes
    -
    -                if (digest.size_bytes >= remote.max_batch_total_size_bytes or
    -                        not remote.batch_update_supported):
    -                    # Too large for batch request, upload in independent request.
    -                    self._send_blob(remote, digest, f, u_uid=u_uid)
    -                else:
    -                    if not batch.add(digest, f):
    -                        # Not enough space left in batch request.
    -                        # Complete pending batch first.
    -                        batch.send()
    -                        batch = _CASBatchUpdate(remote)
    -                        batch.add(digest, f)
    -
    -        # Send final batch
    -        batch.send()
    -
    -
    -# Represents a single remote CAS cache.
    -#
    -class CASRemote():
    -    def __init__(self, spec):
    -        self.spec = spec
    -        self._initialized = False
    -        self.channel = None
    -        self.bytestream = None
    -        self.cas = None
    -        self.ref_storage = None
    -        self.batch_update_supported = None
    -        self.batch_read_supported = None
    -        self.capabilities = None
    -        self.max_batch_total_size_bytes = None
    -
    -    def init(self):
    -        if not self._initialized:
    -            url = urlparse(self.spec.url)
    -            if url.scheme == 'http':
    -                port = url.port or 80
    -                self.channel = grpc.insecure_channel('{}:{}'.format(url.hostname, port))
    -            elif url.scheme == 'https':
    -                port = url.port or 443
    -
    -                if self.spec.server_cert:
    -                    with open(self.spec.server_cert, 'rb') as f:
    -                        server_cert_bytes = f.read()
    -                else:
    -                    server_cert_bytes = None
    -
    -                if self.spec.client_key:
    -                    with open(self.spec.client_key, 'rb') as f:
    -                        client_key_bytes = f.read()
    -                else:
    -                    client_key_bytes = None
    -
    -                if self.spec.client_cert:
    -                    with open(self.spec.client_cert, 'rb') as f:
    -                        client_cert_bytes = f.read()
    -                else:
    -                    client_cert_bytes = None
    -
    -                credentials = grpc.ssl_channel_credentials(root_certificates=server_cert_bytes,
    -                                                           private_key=client_key_bytes,
    -                                                           certificate_chain=client_cert_bytes)
    -                self.channel = grpc.secure_channel('{}:{}'.format(url.hostname, port), credentials)
    -            else:
    -                raise CASError("Unsupported URL: {}".format(self.spec.url))
    -
    -            self.bytestream = bytestream_pb2_grpc.ByteStreamStub(self.channel)
    -            self.cas = remote_execution_pb2_grpc.ContentAddressableStorageStub(self.channel)
    -            self.capabilities = remote_execution_pb2_grpc.CapabilitiesStub(self.channel)
    -            self.ref_storage = buildstream_pb2_grpc.ReferenceStorageStub(self.channel)
    -
    -            self.max_batch_total_size_bytes = _MAX_PAYLOAD_BYTES
    -            try:
    -                request = remote_execution_pb2.GetCapabilitiesRequest(instance_name=self.spec.instance_name)
    -                response = self.capabilities.GetCapabilities(request)
    -                server_max_batch_total_size_bytes = response.cache_capabilities.max_batch_total_size_bytes
    -                if 0 < server_max_batch_total_size_bytes < self.max_batch_total_size_bytes:
    -                    self.max_batch_total_size_bytes = server_max_batch_total_size_bytes
    -            except grpc.RpcError as e:
    -                # Simply use the defaults for servers that don't implement GetCapabilities()
    -                if e.code() != grpc.StatusCode.UNIMPLEMENTED:
    -                    raise
    -
    -            # Check whether the server supports BatchReadBlobs()
    -            self.batch_read_supported = False
    -            try:
    -                request = remote_execution_pb2.BatchReadBlobsRequest(instance_name=self.spec.instance_name)
    -                response = self.cas.BatchReadBlobs(request)
    -                self.batch_read_supported = True
    -            except grpc.RpcError as e:
    -                if e.code() != grpc.StatusCode.UNIMPLEMENTED:
    -                    raise
    -
    -            # Check whether the server supports BatchUpdateBlobs()
    -            self.batch_update_supported = False
    -            try:
    -                request = remote_execution_pb2.BatchUpdateBlobsRequest(instance_name=self.spec.instance_name)
    -                response = self.cas.BatchUpdateBlobs(request)
    -                self.batch_update_supported = True
    -            except grpc.RpcError as e:
    -                if (e.code() != grpc.StatusCode.UNIMPLEMENTED and
    -                        e.code() != grpc.StatusCode.PERMISSION_DENIED):
    -                    raise
    -
    -            self._initialized = True
    -
    -
    -# Represents a batch of blobs queued for fetching.
    -#
    -class _CASBatchRead():
    -    def __init__(self, remote):
    -        self._remote = remote
    -        self._max_total_size_bytes = remote.max_batch_total_size_bytes
    -        self._request = remote_execution_pb2.BatchReadBlobsRequest(instance_name=remote.spec.instance_name)
    -        self._size = 0
    -        self._sent = False
    -
    -    def add(self, digest):
    -        assert not self._sent
    -
    -        new_batch_size = self._size + digest.size_bytes
    -        if new_batch_size > self._max_total_size_bytes:
    -            # Not enough space left in current batch
    -            return False
    -
    -        request_digest = self._request.digests.add()
    -        request_digest.hash = digest.hash
    -        request_digest.size_bytes = digest.size_bytes
    -        self._size = new_batch_size
    -        return True
    -
    -    def send(self):
    -        assert not self._sent
    -        self._sent = True
    -
    -        if not self._request.digests:
    -            return
    -
    -        batch_response = self._remote.cas.BatchReadBlobs(self._request)
    -
    -        for response in batch_response.responses:
    -            if response.status.code == code_pb2.NOT_FOUND:
    -                raise BlobNotFound(response.digest.hash, "Failed to download blob {}: {}".format(
    -                    response.digest.hash, response.status.code))
    -            if response.status.code != code_pb2.OK:
    -                raise CASError("Failed to download blob {}: {}".format(
    -                    response.digest.hash, response.status.code))
    -            if response.digest.size_bytes != len(response.data):
    -                raise CASError("Failed to download blob {}: expected {} bytes, received {} bytes".format(
    -                    response.digest.hash, response.digest.size_bytes, len(response.data)))
    -
    -            yield (response.digest, response.data)
    -
    -
    -# Represents a batch of blobs queued for upload.
    -#
    -class _CASBatchUpdate():
    -    def __init__(self, remote):
    -        self._remote = remote
    -        self._max_total_size_bytes = remote.max_batch_total_size_bytes
    -        self._request = remote_execution_pb2.BatchUpdateBlobsRequest(instance_name=remote.spec.instance_name)
    -        self._size = 0
    -        self._sent = False
    -
    -    def add(self, digest, stream):
    -        assert not self._sent
    -        new_batch_size = self._size + digest.size_bytes
    
    1304
    -        if new_batch_size > self._max_total_size_bytes:
    
    1305
    -            # Not enough space left in current batch
    
    1306
    -            return False
    
    1307
    -
    
    1308
    -        blob_request = self._request.requests.add()
    
    1309
    -        blob_request.digest.hash = digest.hash
    
    1310
    -        blob_request.digest.size_bytes = digest.size_bytes
    
    1311
    -        blob_request.data = stream.read(digest.size_bytes)
    
    1312
    -        self._size = new_batch_size
    
    1313
    -        return True
    
    1314
    -
    
    1315
    -    def send(self):
    
    1316
    -        assert not self._sent
    
    1317
    -        self._sent = True
    
    1318
    -
    
    1319
    -        if not self._request.requests:
    
    1320
    -            return
    
    1321
    -
    
    1322
    -        batch_response = self._remote.cas.BatchUpdateBlobs(self._request)
    
    1323
    -
    
    1324
    -        for response in batch_response.responses:
    
    1325
    -            if response.status.code != code_pb2.OK:
    
    1326
    -                raise CASError("Failed to upload blob {}: {}".format(
    
    1327
    -                    response.digest.hash, response.status.code))
    
    1328
    -
    
    1329
    -
    
    1330
    -def _grouper(iterable, n):
    
    1331
    -    while True:
    
    1332
    -        try:
    
    1333
    -            current = next(iterable)
    
    1334
    -        except StopIteration:
    
    1335
    -            return
    
    1336
    -        yield itertools.chain([current], itertools.islice(iterable, n - 1))

  • buildstream/_cas/casremote.py
+from collections import namedtuple
+import io
+import itertools
+import os
+import multiprocessing
+import signal
+import tempfile
+from urllib.parse import urlparse
+import uuid
+
+import grpc
+
+from .. import _yaml
+from .._protos.google.rpc import code_pb2
+from .._protos.google.bytestream import bytestream_pb2, bytestream_pb2_grpc
+from .._protos.build.bazel.remote.execution.v2 import remote_execution_pb2, remote_execution_pb2_grpc
+from .._protos.buildstream.v2 import buildstream_pb2, buildstream_pb2_grpc
+
+from .._exceptions import CASRemoteError, LoadError, LoadErrorReason
+from .. import _signals
+from .. import utils
+
+# The default limit for gRPC messages is 4 MiB.
+# Limit payload to 1 MiB to leave sufficient headroom for metadata.
+_MAX_PAYLOAD_BYTES = 1024 * 1024
+
+
+class CASRemoteSpec(namedtuple('CASRemoteSpec', 'url push server_cert client_key client_cert instance_name')):
+
+    # _new_from_config_node
+    #
+    # Creates a CASRemoteSpec() from a YAML loaded node
+    #
+    @staticmethod
+    def _new_from_config_node(spec_node, basedir=None):
+        _yaml.node_validate(spec_node, ['url', 'push', 'server-cert', 'client-key', 'client-cert', 'instance_name'])
+        url = _yaml.node_get(spec_node, str, 'url')
+        push = _yaml.node_get(spec_node, bool, 'push', default_value=False)
+        if not url:
+            provenance = _yaml.node_get_provenance(spec_node, 'url')
+            raise LoadError(LoadErrorReason.INVALID_DATA,
+                            "{}: empty artifact cache URL".format(provenance))
+
+        instance_name = _yaml.node_get(spec_node, str, 'instance_name', default_value=None)
+
+        server_cert = _yaml.node_get(spec_node, str, 'server-cert', default_value=None)
+        if server_cert and basedir:
+            server_cert = os.path.join(basedir, server_cert)
+
+        client_key = _yaml.node_get(spec_node, str, 'client-key', default_value=None)
+        if client_key and basedir:
+            client_key = os.path.join(basedir, client_key)
+
+        client_cert = _yaml.node_get(spec_node, str, 'client-cert', default_value=None)
+        if client_cert and basedir:
+            client_cert = os.path.join(basedir, client_cert)
+
+        if client_key and not client_cert:
+            provenance = _yaml.node_get_provenance(spec_node, 'client-key')
+            raise LoadError(LoadErrorReason.INVALID_DATA,
+                            "{}: 'client-key' was specified without 'client-cert'".format(provenance))
+
+        if client_cert and not client_key:
+            provenance = _yaml.node_get_provenance(spec_node, 'client-cert')
+            raise LoadError(LoadErrorReason.INVALID_DATA,
+                            "{}: 'client-cert' was specified without 'client-key'".format(provenance))
+
+        return CASRemoteSpec(url, push, server_cert, client_key, client_cert, instance_name)
+
+
+CASRemoteSpec.__new__.__defaults__ = (None, None, None, None)
+
+
+class BlobNotFound(CASRemoteError):
+
+    def __init__(self, blob, msg):
+        self.blob = blob
+        super().__init__(msg)
+
+
+# Represents a single remote CAS cache.
+#
+class CASRemote():
+    def __init__(self, spec, tmpdir):
+        self.spec = spec
+        self._initialized = False
+        self.channel = None
+        self.bytestream = None
+        self.cas = None
+        self.ref_storage = None
+        self.batch_update_supported = None
+        self.batch_read_supported = None
+        self.capabilities = None
+        self.max_batch_total_size_bytes = None
+
+        # Need str() here because Python 3.5 and lower can't handle
+        # path-like objects.
+        self.tmpdir = str(tmpdir)
+        os.makedirs(self.tmpdir, exist_ok=True)
+
+        self.__tmp_downloads = []  # files in the tmpdir waiting to be added to local caches
+
+        self.__batch_read = None
+        self.__batch_update = None
+
+    def init(self):
+        if not self._initialized:
+            url = urlparse(self.spec.url)
+            if url.scheme == 'http':
+                port = url.port or 80
+                self.channel = grpc.insecure_channel('{}:{}'.format(url.hostname, port))
+            elif url.scheme == 'https':
+                port = url.port or 443
+
+                if self.spec.server_cert:
+                    with open(self.spec.server_cert, 'rb') as f:
+                        server_cert_bytes = f.read()
+                else:
+                    server_cert_bytes = None
+
+                if self.spec.client_key:
+                    with open(self.spec.client_key, 'rb') as f:
+                        client_key_bytes = f.read()
+                else:
+                    client_key_bytes = None
+
+                if self.spec.client_cert:
+                    with open(self.spec.client_cert, 'rb') as f:
+                        client_cert_bytes = f.read()
+                else:
+                    client_cert_bytes = None
+
+                credentials = grpc.ssl_channel_credentials(root_certificates=server_cert_bytes,
+                                                           private_key=client_key_bytes,
+                                                           certificate_chain=client_cert_bytes)
+                self.channel = grpc.secure_channel('{}:{}'.format(url.hostname, port), credentials)
+            else:
+                raise CASRemoteError("Unsupported URL: {}".format(self.spec.url))
+
+            self.bytestream = bytestream_pb2_grpc.ByteStreamStub(self.channel)
+            self.cas = remote_execution_pb2_grpc.ContentAddressableStorageStub(self.channel)
+            self.capabilities = remote_execution_pb2_grpc.CapabilitiesStub(self.channel)
+            self.ref_storage = buildstream_pb2_grpc.ReferenceStorageStub(self.channel)
+
+            self.max_batch_total_size_bytes = _MAX_PAYLOAD_BYTES
+            try:
+                request = remote_execution_pb2.GetCapabilitiesRequest()
+                response = self.capabilities.GetCapabilities(request)
+                server_max_batch_total_size_bytes = response.cache_capabilities.max_batch_total_size_bytes
+                if 0 < server_max_batch_total_size_bytes < self.max_batch_total_size_bytes:
+                    self.max_batch_total_size_bytes = server_max_batch_total_size_bytes
+            except grpc.RpcError as e:
+                # Simply use the defaults for servers that don't implement GetCapabilities()
+                if e.code() != grpc.StatusCode.UNIMPLEMENTED:
+                    raise
+
+            # Check whether the server supports BatchReadBlobs()
+            self.batch_read_supported = False
+            try:
+                request = remote_execution_pb2.BatchReadBlobsRequest()
+                response = self.cas.BatchReadBlobs(request)
+                self.batch_read_supported = True
+                self.__batch_read = _CASBatchRead(self)
+            except grpc.RpcError as e:
+                if e.code() != grpc.StatusCode.UNIMPLEMENTED:
+                    raise
+
+            # Check whether the server supports BatchUpdateBlobs()
+            self.batch_update_supported = False
+            try:
+                request = remote_execution_pb2.BatchUpdateBlobsRequest()
+                response = self.cas.BatchUpdateBlobs(request)
+                self.batch_update_supported = True
+                self.__batch_update = _CASBatchUpdate(self)
+            except grpc.RpcError as e:
+                if (e.code() != grpc.StatusCode.UNIMPLEMENTED and
+                        e.code() != grpc.StatusCode.PERMISSION_DENIED):
+                    raise
+
+            self._initialized = True
+
+    # check_remote
+    #
+    # Used when checking whether remote_specs work from the main BuildStream
+    # process; runs the check in a separate process to avoid creating gRPC
+    # threads in the main BuildStream process.
+    # See https://github.com/grpc/grpc/blob/master/doc/fork_support.md for details
+    #
+    @classmethod
+    def check_remote(cls, remote_spec, tmpdir, q):
+
+        def __check_remote():
+            try:
+                remote = cls(remote_spec, tmpdir)
+                remote.init()
+
+                request = buildstream_pb2.StatusRequest()
+                response = remote.ref_storage.Status(request)
+
+                if remote_spec.push and not response.allow_updates:
+                    q.put('CAS server does not allow push')
+                else:
+                    # No error
+                    q.put(None)
+
+            except grpc.RpcError as e:
+                # str(e) is too verbose for errors reported to the user
+                q.put(e.details())
+
+            except Exception as e:               # pylint: disable=broad-except
+                # Whatever happens, we need to return it to the calling process
+                #
+                q.put(str(e))
+
+        p = multiprocessing.Process(target=__check_remote)
+
+        try:
+            # Keep SIGINT blocked in the child process
+            with _signals.blocked([signal.SIGINT], ignore=False):
+                p.start()
+
+            error = q.get()
+            p.join()
+        except KeyboardInterrupt:
+            utils._kill_process_tree(p.pid)
+            raise
+
+        return error
+
+    # verify_digest_on_remote():
+    #
+    # Check whether the object is already on the server, in which case
+    # there is no need to upload it.
+    #
+    # Args:
+    #     digest (Digest): The object digest.
+    #
+    def verify_digest_on_remote(self, digest):
+        self.init()
+
+        request = remote_execution_pb2.FindMissingBlobsRequest()
+        request.blob_digests.extend([digest])
+
+        response = self.cas.FindMissingBlobs(request)
+        if digest in response.missing_blob_digests:
+            return False
+
+        return True
+
+    # push_message():
+    #
+    # Push the given protobuf message to a remote.
+    #
+    # Args:
+    #     message (Message): A protobuf message to push.
+    #
+    # Raises:
+    #     (CASRemoteError): if there was an error
+    #
+    def push_message(self, message):
+
+        message_buffer = message.SerializeToString()
+        message_digest = utils._message_digest(message_buffer)
+
+        self.init()
+
+        with io.BytesIO(message_buffer) as b:
+            self._send_blob(message_digest, b)
+
+        return message_digest
+
+    # get_reference():
+    #
+    # Args:
+    #    ref (str): The ref to request
+    #
+    # Returns:
+    #    (digest): digest of ref, None if not found
+    #
+    def get_reference(self, ref):
+        try:
+            self.init()
+
+            request = buildstream_pb2.GetReferenceRequest()
+            request.key = ref
+            return self.ref_storage.GetReference(request).digest
+        except grpc.RpcError as e:
+            if e.code() != grpc.StatusCode.NOT_FOUND:
+                raise CASRemoteError("Failed to find ref {}: {}".format(ref, e)) from e
+            else:
+                return None
+
+    # update_reference():
+    #
+    # Args:
+    #    ref (str): Reference to update
+    #    digest (Digest): New digest to update ref with
+    #
+    def update_reference(self, ref, digest):
+        request = buildstream_pb2.UpdateReferenceRequest()
+        request.keys.append(ref)
+        request.digest.hash = digest.hash
+        request.digest.size_bytes = digest.size_bytes
+        self.ref_storage.UpdateReference(request)
+
+    def get_tree_blob(self, tree_digest):
+        self.init()
+        f = tempfile.NamedTemporaryFile(dir=self.tmpdir)
+        self._fetch_blob(tree_digest, f)
+
+        tree = remote_execution_pb2.Tree()
+        with open(f.name, 'rb') as tmp:
+            tree.ParseFromString(tmp.read())
+
+        return tree
+
+    # yield_directory_digests():
+    #
+    # Recursively iterates over digests for files, symbolic links and other
+    # directories starting from a root digest
+    #
+    # Args:
+    #     root_digest (digest): The root_digest to get a tree of
+    #     progress (callable): The progress callback, if any
+    #     subdir (str): The optional specific subdir to pull
+    #     excluded_subdirs (list): The optional list of subdirs to not pull
+    #
+    # Returns:
+    #     (iter digests): recursively iterates over digests contained in root directory
+    #
+    def yield_directory_digests(self, root_digest, *, progress=None,
+                                subdir=None, excluded_subdirs=None):
+        self.init()
+
+        # Fetch artifact, excluded_subdirs determined in pullqueue
+        if excluded_subdirs is None:
+            excluded_subdirs = []
+
+        # get directory blob
+        f = tempfile.NamedTemporaryFile(dir=self.tmpdir)
+        self._fetch_blob(root_digest, f)
+
+        directory = remote_execution_pb2.Directory()
+        with open(f.name, 'rb') as tmp:
+            directory.ParseFromString(tmp.read())
+
+        yield root_digest
+        for filenode in directory.files:
+            yield filenode.digest
+
+        for dirnode in directory.directories:
+            if dirnode.name not in excluded_subdirs:
+                yield from self.yield_directory_digests(dirnode.digest)
+
+    # yield_tree_digests():
+    #
+    # Fetches a tree file from digests and then iterates over child digests
+    #
+    # Args:
+    #     tree_digest (digest): tree digest
+    #
+    # Returns:
+    #     (iter digests): iterates over digests in tree message
+    #
+    def yield_tree_digests(self, tree_digest):
+        self.init()
+
+        # get tree file
+        f = tempfile.NamedTemporaryFile(dir=self.tmpdir)
+        self._fetch_blob(tree_digest, f)
+        tree = remote_execution_pb2.Tree()
+        # Rewind before parsing; _fetch_blob() leaves the file positioned at EOF
+        f.seek(0)
+        tree.ParseFromString(f.read())
+
+        tree.children.extend([tree.root])
+        for directory in tree.children:
+            for filenode in directory.files:
+                yield filenode.digest
+
+            # add the directory to downloaded tmp files to be added
+            f2 = tempfile.NamedTemporaryFile(dir=self.tmpdir)
+            f2.write(directory.SerializeToString())
+            f2.flush()
+            self.__tmp_downloads.append(f2)
+
+        # Add the tree directory to downloads right at the end
+        self.__tmp_downloads.append(f)
+
+    # request_blob():
+    #
+    # Request a blob, downloading it individually via bytestream or queueing
+    # it on a CAS BatchReadBlobs request, depending on its size.
+    #
+    # Args:
+    #    digest (Digest): digest of the requested blob
+    #
+    def request_blob(self, digest):
+        if (not self.batch_read_supported or
+                digest.size_bytes > self.max_batch_total_size_bytes):
+            f = tempfile.NamedTemporaryFile(dir=self.tmpdir)
+            self._fetch_blob(digest, f)
+            self.__tmp_downloads.append(f)
+        elif not self.__batch_read.add(digest):
+            self._download_batch()
+            self.__batch_read.add(digest)
+
+    # get_blobs():
+    #
+    # Yield over downloaded blobs in the tmp file locations, causing the files
+    # to be deleted once they go out of scope.
+    #
+    # Args:
+    #    complete_batch (bool): download any outstanding batch read request
+    #
+    # Returns:
+    #    iterator over NamedTemporaryFile
+    #
+    def get_blobs(self, complete_batch=False):
+        # Send read batch request and download
+        if complete_batch and self.batch_read_supported:
+            self._download_batch()
+
+        while self.__tmp_downloads:
+            yield self.__tmp_downloads.pop()
+
+    # upload_blob():
+    #
+    # Push a blob, either individually or as part of a batch, given the
+    # file it is stored in.
+    #
+    # Args:
+    #    digest (Digest): digest of the blob we want to upload
+    #    blob_file (str): Name of file location
+    #    u_uid (str): Used to identify to the bytestream service
+    #
+    def upload_blob(self, digest, blob_file, u_uid=uuid.uuid4()):
+        with open(blob_file, 'rb') as f:
+            assert os.fstat(f.fileno()).st_size == digest.size_bytes
+
+            if (digest.size_bytes >= self.max_batch_total_size_bytes or
+                    not self.batch_update_supported):
+                # Too large for batch request, upload in independent request.
+                self._send_blob(digest, f, u_uid=u_uid)
+            else:
+                if not self.__batch_update.add(digest, f):
+                    self.__batch_update.send()
+                    self.__batch_update = _CASBatchUpdate(self)
+                    self.__batch_update.add(digest, f)
+
+    # send_update_batch():
+    #
+    # Sends anything left in the update batch
+    #
+    def send_update_batch(self):
+        # make sure everything is sent
+        self.__batch_update.send()
+        self.__batch_update = _CASBatchUpdate(self)
+
+    # find_missing_blobs()
+    #
+    # Sends a FindMissingBlobs request to the remote
+    #
+    # Args:
+    #    required_blobs ([Digest]): list of blobs required
+    #
+    # Returns:
+    #    (dict): missing blobs, keyed by digest hash
+    #
+    def find_missing_blobs(self, required_blobs):
+        self.init()
+        missing_blobs = dict()
+        # Limit size of FindMissingBlobs request
+        for required_blobs_group in _grouper(required_blobs, 512):
+            request = remote_execution_pb2.FindMissingBlobsRequest()
+
+            for required_digest in required_blobs_group:
+                d = request.blob_digests.add()
+                d.hash = required_digest.hash
+                d.size_bytes = required_digest.size_bytes
+
+            response = self.cas.FindMissingBlobs(request)
+            for missing_digest in response.missing_blob_digests:
+                d = remote_execution_pb2.Digest()
+                d.hash = missing_digest.hash
+                d.size_bytes = missing_digest.size_bytes
+                missing_blobs[d.hash] = d
+
+        return missing_blobs
+
+    ################################################
+    #             Local Private Methods            #
+    ################################################
+    def _fetch_blob(self, digest, stream):
+        resource_name = '/'.join(['blobs', digest.hash, str(digest.size_bytes)])
+        request = bytestream_pb2.ReadRequest()
+        request.resource_name = resource_name
+        request.read_offset = 0
+        for response in self.bytestream.Read(request):
+            stream.write(response.data)
+        stream.flush()
+
+        assert digest.size_bytes == os.fstat(stream.fileno()).st_size
+
+    def _send_blob(self, digest, stream, u_uid=uuid.uuid4()):
+        resource_name = '/'.join(['uploads', str(u_uid), 'blobs',
+                                  digest.hash, str(digest.size_bytes)])
+
+        def request_stream(resname, instream):
+            offset = 0
+            finished = False
+            remaining = digest.size_bytes
+            while not finished:
+                chunk_size = min(remaining, _MAX_PAYLOAD_BYTES)
+                remaining -= chunk_size
+
+                request = bytestream_pb2.WriteRequest()
+                request.write_offset = offset
+                # max. _MAX_PAYLOAD_BYTES chunks
+                request.data = instream.read(chunk_size)
+                request.resource_name = resname
+                request.finish_write = remaining <= 0
+
+                yield request
+
+                offset += chunk_size
+                finished = request.finish_write
+
+        try:
+            response = self.bytestream.Write(request_stream(resource_name, stream))
+        except grpc.RpcError as e:
+            raise CASRemoteError("Failed to upload blob: {}".format(e), reason=e.code())
+
+        assert response.committed_size == digest.size_bytes
+
+    def _download_batch(self):
+        for _, data in self.__batch_read.send():
+            f = tempfile.NamedTemporaryFile(dir=self.tmpdir)
+            f.write(data)
+            f.flush()
+            self.__tmp_downloads.append(f)
+
+        self.__batch_read = _CASBatchRead(self)
+
+
+def _grouper(iterable, n):
+    while True:
+        try:
+            current = next(iterable)
+        except StopIteration:
+            return
+        yield itertools.chain([current], itertools.islice(iterable, n - 1))
+
+
+# Represents a batch of blobs queued for fetching.
+#
+class _CASBatchRead():
+    def __init__(self, remote):
+        self._remote = remote
+        self._max_total_size_bytes = remote.max_batch_total_size_bytes
+        self._request = remote_execution_pb2.BatchReadBlobsRequest()
+        self._size = 0
+        self._sent = False
+
+    def add(self, digest):
+        assert not self._sent
+
+        new_batch_size = self._size + digest.size_bytes
+        if new_batch_size > self._max_total_size_bytes:
+            # Not enough space left in current batch
+            return False
+
+        request_digest = self._request.digests.add()
+        request_digest.hash = digest.hash
+        request_digest.size_bytes = digest.size_bytes
+        self._size = new_batch_size
+        return True
+
+    def send(self):
+        assert not self._sent
+        self._sent = True
+
+        if not self._request.digests:
+            return
+
+        try:
+            batch_response = self._remote.cas.BatchReadBlobs(self._request)
+        except grpc.RpcError as e:
+            raise CASRemoteError("Failed to read blob batch: {}".format(e),
+                                 reason=e.code()) from e
+
+        for response in batch_response.responses:
+            if response.status.code == code_pb2.NOT_FOUND:
+                raise BlobNotFound(response.digest.hash, "Failed to download blob {}: {}".format(
+                    response.digest.hash, response.status.code))
+            if response.status.code != code_pb2.OK:
+                raise CASRemoteError("Failed to download blob {}: {}".format(
+                    response.digest.hash, response.status.code))
+            if response.digest.size_bytes != len(response.data):
+                raise CASRemoteError("Failed to download blob {}: expected {} bytes, received {} bytes".format(
+                    response.digest.hash, response.digest.size_bytes, len(response.data)))
+
+            yield (response.digest, response.data)
+
+
+# Represents a batch of blobs queued for upload.
+#
+class _CASBatchUpdate():
+    def __init__(self, remote):
+        self._remote = remote
+        self._max_total_size_bytes = remote.max_batch_total_size_bytes
+        self._request = remote_execution_pb2.BatchUpdateBlobsRequest()
+        self._size = 0
+        self._sent = False
+
+    def add(self, digest, stream):
+        assert not self._sent
+
+        new_batch_size = self._size + digest.size_bytes
+        if new_batch_size > self._max_total_size_bytes:
+            # Not enough space left in current batch
+            return False
+
+        blob_request = self._request.requests.add()
+        blob_request.digest.hash = digest.hash
+        blob_request.digest.size_bytes = digest.size_bytes
+        blob_request.data = stream.read(digest.size_bytes)
+        self._size = new_batch_size
+        return True
+
+    def send(self):
+        assert not self._sent
+        self._sent = True
+
+        if not self._request.requests:
+            return
+
+        # Want to raise a CASRemoteError if the batch request itself fails
+        try:
+            batch_response = self._remote.cas.BatchUpdateBlobs(self._request)
+        except grpc.RpcError as e:
+            raise CASRemoteError("Failed to upload blob batch: {}".format(e),
+                                 reason=e.code()) from e
+
+        for response in batch_response.responses:
+            if response.status.code != code_pb2.OK:
+                raise CASRemoteError("Failed to upload blob {}: {}".format(
+                    response.digest.hash, response.status.code))
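The fetch flow above splits queueing from collection: request_blob() either queues a small blob on the current _CASBatchRead or falls back to an individual bytestream read for anything over max_batch_total_size_bytes, and get_blobs() drains the downloaded temporary files. A minimal sketch of how a caller might drive this API; commit_to_local_cache() is a hypothetical stand-in for whatever the local CAS does with the files, and remote_spec, tmpdir and digests are assumed inputs:

    from buildstream._cas.casremote import CASRemote

    def commit_to_local_cache(tmp_file):
        # Hypothetical stand-in: adding the object to the local CAS
        # is the local cache's job, not this sketch's.
        print("downloaded", tmp_file.name)

    def fetch_blobs(remote_spec, tmpdir, digests):
        remote = CASRemote(remote_spec, tmpdir)
        remote.init()

        for digest in digests:
            # Queue on the current batch, or fall back to a single
            # bytestream read if the blob is too large to batch.
            remote.request_blob(digest)

            # Drain any blobs whose batch has already been downloaded.
            for f in remote.get_blobs():
                commit_to_local_cache(f)

        # Flush the final, partially filled batch.
        for f in remote.get_blobs(complete_batch=True):
            commit_to_local_cache(f)

Note that request_blob() only hits the network when a batch fills up, so callers must finish with get_blobs(complete_batch=True) or the last partial batch is never downloaded.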

  • buildstream/_context.py
... ... @@ -187,10 +187,11 @@ class Context():
         _yaml.node_validate(defaults, [
             'sourcedir', 'builddir', 'artifactdir', 'logdir',
             'scheduler', 'artifacts', 'logging', 'projects',
-            'cache', 'prompt', 'workspacedir',
+            'cache', 'prompt', 'workspacedir', 'tmpdir'
         ])
 
-        for directory in ['sourcedir', 'builddir', 'artifactdir', 'logdir', 'workspacedir']:
+        for directory in ['sourcedir', 'builddir', 'artifactdir', 'logdir',
+                          'tmpdir', 'workspacedir']:
             # Allow the ~ tilde expansion and any environment variables in
             # path specification in the config files.
             #
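The new 'tmpdir' entry goes through the same path expansion as the other directories. A rough sketch of what that expansion does to the default value from userconfig.yaml, assuming the standard os.path helpers (the exact expansion code sits outside this hunk):

    import os

    # The default from userconfig.yaml; '~' and environment variables
    # are both allowed in these settings.
    path = "${XDG_CACHE_HOME}/buildstream/tmp"
    expanded = os.path.expanduser(os.path.expandvars(path))
    # e.g. /home/user/.cache/buildstream/tmp when XDG_CACHE_HOME is set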

  • buildstream/_exceptions.py
... ... @@ -284,6 +284,21 @@ class CASError(BstError):
         super().__init__(message, detail=detail, domain=ErrorDomain.CAS, reason=reason, temporary=True)
 
 
+# CASRemoteError
+#
+# Raised when errors are encountered in the remote CAS
+#
+class CASRemoteError(CASError):
+    pass
+
+
+# CASCacheError
+#
+# Raised when errors are encountered in the local CASCache
+#
+class CASCacheError(CASError):
+    pass
+
+
 # PipelineError
 #
 # Raised from pipeline operations
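Subclassing CASError this way lets callers tell a blob that is simply absent from the remote (BlobNotFound, defined in casremote.py above) apart from a genuine remote failure, while existing 'except CASError' handlers keep working. A small sketch of the intended catch order; pull_one_blob() is a hypothetical wrapper around the fetch API from casremote.py:

    from buildstream._exceptions import CASRemoteError
    from buildstream._cas.casremote import BlobNotFound

    def pull_one_blob(remote, digest):
        try:
            remote.request_blob(digest)
            return list(remote.get_blobs(complete_batch=True))
        except BlobNotFound:
            # The server doesn't have this blob; the caller can try
            # another remote or fall back to building locally.
            return None
        except CASRemoteError:
            # Anything else is a real remote failure and should surface.
            raise

Catching BlobNotFound before CASRemoteError matters: it subclasses CASRemoteError, so the broader handler would otherwise swallow it.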

  • buildstream/data/userconfig.yaml
... ... @@ -19,6 +19,9 @@ builddir: ${XDG_CACHE_HOME}/buildstream/build
 # Location to store local binary artifacts
 artifactdir: ${XDG_CACHE_HOME}/buildstream/artifacts
 
+# Location to store temporary files, e.g. those created while downloading from a remote CAS
+tmpdir: ${XDG_CACHE_HOME}/buildstream/tmp
+
 # Location to store build logs
 logdir: ${XDG_CACHE_HOME}/buildstream/logs
 

  • buildstream/sandbox/_sandboxremote.py
... ... @@ -244,7 +244,7 @@ class SandboxRemote(Sandbox):
 
         context = self._get_context()
         cascache = context.get_cascache()
-        casremote = CASRemote(self.storage_remote_spec)
+        casremote = CASRemote(self.storage_remote_spec, context.tmpdir)
 
         # Now do a pull to ensure we have the necessary parts.
         dir_digest = cascache.pull_tree(casremote, tree_digest)
... ... @@ -271,8 +271,9 @@ class SandboxRemote(Sandbox):
 
     def _run(self, command, flags, *, cwd, env):
         # set up virtual directory
+        context = self._get_context()
         upload_vdir = self.get_virtual_directory()
-        cascache = self._get_context().get_cascache()
+        cascache = context.get_cascache()
         if isinstance(upload_vdir, FileBasedDirectory):
             # Make a new temporary directory to put source in
             upload_vdir = CasBasedDirectory(cascache, ref=None)
... ... @@ -303,7 +304,7 @@ class SandboxRemote(Sandbox):
         action_result = self._check_action_cache(action_digest)
 
         if not action_result:
-            casremote = CASRemote(self.storage_remote_spec)
+            casremote = CASRemote(self.storage_remote_spec, context.tmpdir)
 
             # Now, push that key (without necessarily needing a ref) to the remote.
             try:
... ... @@ -311,17 +312,17 @@ class SandboxRemote(Sandbox):
             except grpc.RpcError as e:
                 raise SandboxError("Failed to push source directory to remote: {}".format(e)) from e
 
-            if not cascache.verify_digest_on_remote(casremote, upload_vdir.ref):
+            if not casremote.verify_digest_on_remote(upload_vdir.ref):
                 raise SandboxError("Failed to verify that source has been pushed to the remote artifact cache.")
 
             # Push command and action
             try:
-                cascache.push_message(casremote, command_proto)
+                casremote.push_message(command_proto)
             except grpc.RpcError as e:
                 raise SandboxError("Failed to push command to remote: {}".format(e))
 
             try:
-                cascache.push_message(casremote, action)
+                casremote.push_message(action)
             except grpc.RpcError as e:
                 raise SandboxError("Failed to push action to remote: {}".format(e))
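With verify_digest_on_remote() and push_message() moved onto CASRemote, the sandbox no longer routes remote calls through the local cache. The same call shape in isolation; spec, tmpdir and message stand in for the sandbox's real values, and the surrounding SandboxError handling is elided:

    from buildstream._cas.casremote import CASRemote

    def push_proto(spec, tmpdir, message):
        remote = CASRemote(spec, tmpdir)

        # Was: cascache.push_message(casremote, message)
        digest = remote.push_message(message)

        # Was: cascache.verify_digest_on_remote(casremote, digest)
        if not remote.verify_digest_on_remote(digest):
            raise RuntimeError("message not present on remote after push")
        return digest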

  • tests/artifactcache/pull.py
... ... @@ -110,7 +110,7 @@ def test_pull(cli, tmpdir, datafiles):
         # See https://github.com/grpc/grpc/blob/master/doc/fork_support.md for details
         process = multiprocessing.Process(target=_queue_wrapper,
                                           args=(_test_pull, queue, user_config_file, project_dir,
-                                                artifact_dir, 'target.bst', element_key))
+                                                artifact_dir, tmpdir, 'target.bst', element_key))
 
         try:
             # Keep SIGINT blocked in the child process
... ... @@ -126,14 +126,18 @@ def test_pull(cli, tmpdir, datafiles):
         assert not error
         assert cas.contains(element, element_key)
 
+        # Check that the tmp dir is cleared out
+        assert os.listdir(os.path.join(str(tmpdir), 'cache', 'tmp')) == []
 
-def _test_pull(user_config_file, project_dir, artifact_dir,
+
+def _test_pull(user_config_file, project_dir, artifact_dir, tmpdir,
                element_name, element_key, queue):
     # Fake minimal context
     context = Context()
     context.load(config=user_config_file)
     context.artifactdir = artifact_dir
     context.set_message_handler(message_handler)
+    context.tmpdir = os.path.join(str(tmpdir), 'cache', 'tmp')
 
     # Load the project manually
     project = Project(project_dir, context)
... ... @@ -218,7 +222,7 @@ def test_pull_tree(cli, tmpdir, datafiles):
         # See https://github.com/grpc/grpc/blob/master/doc/fork_support.md for details
         process = multiprocessing.Process(target=_queue_wrapper,
                                           args=(_test_push_tree, queue, user_config_file, project_dir,
-                                                artifact_dir, artifact_digest))
+                                                artifact_dir, tmpdir, artifact_digest))
 
         try:
             # Keep SIGINT blocked in the child process
... ... @@ -239,6 +243,9 @@ def test_pull_tree(cli, tmpdir, datafiles):
         # Assert that we are not cached locally anymore
         assert cli.get_element_state(project_dir, 'target.bst') != 'cached'
 
+        # Check that the tmp dir is cleared out
+        assert os.listdir(os.path.join(str(tmpdir), 'cache', 'tmp')) == []
+
         tree_digest = remote_execution_pb2.Digest(hash=tree_hash,
                                                   size_bytes=tree_size)
 
... ... @@ -246,7 +253,7 @@ def test_pull_tree(cli, tmpdir, datafiles):
         # Use subprocess to avoid creation of gRPC threads in main BuildStream process
         process = multiprocessing.Process(target=_queue_wrapper,
                                           args=(_test_pull_tree, queue, user_config_file, project_dir,
-                                                artifact_dir, tree_digest))
+                                                artifact_dir, tmpdir, tree_digest))
 
         try:
             # Keep SIGINT blocked in the child process
... ... @@ -267,13 +274,18 @@ def test_pull_tree(cli, tmpdir, datafiles):
         # Ensure the entire Tree structure has been pulled
         assert os.path.exists(cas.objpath(directory_digest))
 
+        # Check that the tmp dir is cleared out
+        assert os.listdir(os.path.join(str(tmpdir), 'cache', 'tmp')) == []
+
 
-def _test_push_tree(user_config_file, project_dir, artifact_dir, artifact_digest, queue):
+def _test_push_tree(user_config_file, project_dir, artifact_dir, tmpdir,
+                    artifact_digest, queue):
     # Fake minimal context
     context = Context()
     context.load(config=user_config_file)
     context.artifactdir = artifact_dir
     context.set_message_handler(message_handler)
+    context.tmpdir = os.path.join(str(tmpdir), 'cache', 'tmp')
 
     # Load the project manually
     project = Project(project_dir, context)
... ... @@ -304,12 +316,14 @@ def _test_push_tree(user_config_file, project_dir, artifact_dir, artifact_digest
         queue.put("No remote configured")
 
 
-def _test_pull_tree(user_config_file, project_dir, artifact_dir, artifact_digest, queue):
+def _test_pull_tree(user_config_file, project_dir, artifact_dir, tmpdir,
+                    artifact_digest, queue):
     # Fake minimal context
     context = Context()
     context.load(config=user_config_file)
     context.artifactdir = artifact_dir
     context.set_message_handler(message_handler)
+    context.tmpdir = os.path.join(str(tmpdir), 'cache', 'tmp')
 
     # Load the project manually
     project = Project(project_dir, context)

  • tests/artifactcache/push.py
    ... ... @@ -89,7 +89,7 @@ def test_push(cli, tmpdir, datafiles):
    89 89
             # See https://github.com/grpc/grpc/blob/master/doc/fork_support.md for details
    
    90 90
             process = multiprocessing.Process(target=_queue_wrapper,
    
    91 91
                                               args=(_test_push, queue, user_config_file, project_dir,
    
    92
    -                                                artifact_dir, 'target.bst', element_key))
    
    92
    +                                                artifact_dir, tmpdir, 'target.bst', element_key))
    
    93 93
     
    
    94 94
             try:
    
    95 95
                 # Keep SIGINT blocked in the child process
    
    ... ... @@ -105,14 +105,18 @@ def test_push(cli, tmpdir, datafiles):
    105 105
             assert not error
    
    106 106
             assert share.has_artifact('test', 'target.bst', element_key)
    
    107 107
     
    
    108
    +        # Check tmpdir for downloads is cleared
    
    109
    +        assert os.listdir(os.path.join(str(tmpdir), 'cache', 'tmp')) == []
    
    108 110
     
    
    109
    -def _test_push(user_config_file, project_dir, artifact_dir,
    +
    +def _test_push(user_config_file, project_dir, artifact_dir, tmpdir,
                    element_name, element_key, queue):
         # Fake minimal context
         context = Context()
         context.load(config=user_config_file)
         context.artifactdir = artifact_dir
         context.set_message_handler(message_handler)
    +    context.tmpdir = os.path.join(str(tmpdir), 'cache', 'tmp')

         # Load the project manually
         project = Project(project_dir, context)

    @@ -196,9 +200,10 @@ def test_push_directory(cli, tmpdir, datafiles):
             queue = multiprocessing.Queue()
             # Use subprocess to avoid creation of gRPC threads in main BuildStream process
             # See https://github.com/grpc/grpc/blob/master/doc/fork_support.md for details
    -        process = multiprocessing.Process(target=_queue_wrapper,
    -                                          args=(_test_push_directory, queue, user_config_file,
    -                                                project_dir, artifact_dir, artifact_digest))
    +        process = multiprocessing.Process(
    +            target=_queue_wrapper,
    +            args=(_test_push_directory, queue, user_config_file, project_dir,
    +                  artifact_dir, tmpdir, artifact_digest))

             try:
                 # Keep SIGINT blocked in the child process

    @@ -215,13 +220,17 @@ def test_push_directory(cli, tmpdir, datafiles):
             assert artifact_digest.hash == directory_hash
             assert share.has_object(artifact_digest)

    +        assert os.listdir(os.path.join(str(tmpdir), 'cache', 'tmp')) == []

    -def _test_push_directory(user_config_file, project_dir, artifact_dir, artifact_digest, queue):
    +
    +def _test_push_directory(user_config_file, project_dir, artifact_dir, tmpdir,
    +                         artifact_digest, queue):
         # Fake minimal context
         context = Context()
         context.load(config=user_config_file)
         context.artifactdir = artifact_dir
         context.set_message_handler(message_handler)
    +    context.tmpdir = os.path.join(str(tmpdir), 'cache', 'tmp')

         # Load the project manually
         project = Project(project_dir, context)

    @@ -273,7 +282,7 @@ def test_push_message(cli, tmpdir, datafiles):
             # See https://github.com/grpc/grpc/blob/master/doc/fork_support.md for details
             process = multiprocessing.Process(target=_queue_wrapper,
                                               args=(_test_push_message, queue, user_config_file,
    -                                                project_dir, artifact_dir))
    +                                                project_dir, artifact_dir, tmpdir))

             try:
                 # Keep SIGINT blocked in the child process

    @@ -291,13 +300,16 @@ def test_push_message(cli, tmpdir, datafiles):
                                                      size_bytes=message_size)
             assert share.has_object(message_digest)

    +        assert os.listdir(os.path.join(str(tmpdir), 'cache', 'tmp')) == []
    +

    -def _test_push_message(user_config_file, project_dir, artifact_dir, queue):
    +def _test_push_message(user_config_file, project_dir, artifact_dir, tmpdir, queue):
         # Fake minimal context
         context = Context()
         context.load(config=user_config_file)
         context.artifactdir = artifact_dir
         context.set_message_handler(message_handler)
    +    context.tmpdir = os.path.join(str(tmpdir), 'cache', 'tmp')

         # Load the project manually
         project = Project(project_dir, context)
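
    For context, the _test_push* helpers above run in a child process via
    _queue_wrapper so that gRPC threads are never created in the main
    BuildStream process. A minimal sketch of that pattern follows; the
    _queue_wrapper and _example_body shown here are illustrative stand-ins,
    not the repository's actual helpers:

        import multiprocessing


        def _queue_wrapper(target, queue, *args):
            # Run the test body; report any failure back to the parent via the queue.
            try:
                target(*args, queue=queue)
            except Exception as e:
                queue.put('error: {}'.format(e))
                raise


        def _example_body(value, queue):
            # Stand-in for helpers like _test_push: do some work, report the result.
            queue.put(value * 2)


        if __name__ == '__main__':
            queue = multiprocessing.Queue()
            process = multiprocessing.Process(target=_queue_wrapper,
                                              args=(_example_body, queue, 21))
            process.start()
            result = queue.get()
            process.join()
            assert result == 42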
    

  • tests/integration/pullbuildtrees.py

    @@ -23,6 +23,7 @@ def default_state(cli, tmpdir, share):
             'artifacts': {'url': share.repo, 'push': False},
             'artifactdir': os.path.join(str(tmpdir), 'artifacts'),
             'cache': {'pull-buildtrees': False},
    +        'tmpdir': os.path.join(str(tmpdir), 'cache', 'tmp'),
         })


    @@ -79,6 +80,9 @@ def test_pullbuildtrees(cli, tmpdir, datafiles, integration_cache):
             assert os.path.isdir(buildtreedir)
             default_state(cli, tmpdir, share1)

    +        # Check tmpdir for downloads is cleared
    +        assert os.listdir(os.path.join(integration_cache, 'tmp')) == []
    +
             # Pull artifact with pullbuildtrees set in user config, then assert
             # that pulling with the same user config doesn't create a pull job,
             # or when the buildtrees cli flag is set.

    @@ -91,6 +95,9 @@ def test_pullbuildtrees(cli, tmpdir, datafiles, integration_cache):
             assert element_name not in result.get_pulled_elements()
             default_state(cli, tmpdir, share1)

    +        # Check tmpdir for downloads is cleared
    +        assert os.listdir(os.path.join(integration_cache, 'tmp')) == []
    +
             # Pull artifact with default config and buildtrees cli flag set, then assert
             # that pulling with pullbuildtrees set in user config doesn't create a pull
             # job.

    @@ -101,6 +108,9 @@ def test_pullbuildtrees(cli, tmpdir, datafiles, integration_cache):
             assert element_name not in result.get_pulled_elements()
             default_state(cli, tmpdir, share1)

    +        # Check tmpdir for downloads is cleared
    +        assert os.listdir(os.path.join(integration_cache, 'tmp')) == []
    +
             # Assert that a partial build element (not containing a populated buildtree dir)
             # can't be pushed to an artifact share, then assert that a complete build element
             # can be. This will attempt a partial pull from share1 and then a partial push
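
    The repeated os.listdir(...) == [] assertions pin down the invariant this
    branch introduces: blobs are staged under the configured tmpdir while a
    download is in flight and removed once committed, so the directory is
    empty between pulls. A minimal sketch of that staging pattern, assuming a
    hypothetical fetch_blob helper (not BuildStream's actual API):

        import os
        import tempfile


        def fetch_blob(tmpdir, data):
            # Stage the incoming blob under tmpdir while the download is in flight...
            fd, path = tempfile.mkstemp(dir=tmpdir)
            with os.fdopen(fd, 'wb') as f:
                f.write(data)
            # ...and remove the staging file once the blob is committed to the store.
            os.unlink(path)


        staging = tempfile.mkdtemp()
        fetch_blob(staging, b'artifact blob')
        assert os.listdir(staging) == []  # the invariant the tests above assert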
    

  • tests/testutils/runcli.py

    @@ -509,7 +509,8 @@ def cli_integration(tmpdir, integration_cache):
         # to avoid downloading the huge base-sdk repeatedly
         fixture.configure({
             'sourcedir': os.path.join(integration_cache, 'sources'),
    -        'artifactdir': os.path.join(integration_cache, 'artifacts')
    +        'artifactdir': os.path.join(integration_cache, 'artifacts'),
    +        'tmpdir': os.path.join(integration_cache, 'tmp')
         })

         return fixture

    @@ -556,6 +557,8 @@ def configured(directory, config=None):
             config['builddir'] = os.path.join(directory, 'build')
         if not config.get('artifactdir', False):
             config['artifactdir'] = os.path.join(directory, 'artifacts')
    +    if not config.get('tmpdir', False):
    +        config['tmpdir'] = os.path.join(directory, 'tmp')
         if not config.get('logdir', False):
             config['logdir'] = os.path.join(directory, 'logs')
    



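    The configured fixture fills in a default path under the test directory
    for each cache directory the caller did not set, and 'tmpdir' now gets a
    default alongside the others. A condensed sketch of the same defaulting
    pattern; apply_defaults is an illustrative name, not the fixture's real
    one:

        import os

        # Keys and subdirectory names as they appear in the diff above.
        _DEFAULT_SUBDIRS = {
            'builddir': 'build',
            'artifactdir': 'artifacts',
            'tmpdir': 'tmp',
            'logdir': 'logs',
        }


        def apply_defaults(directory, config=None):
            # Fill in a default path under `directory` for any key the caller omitted.
            config = dict(config or {})
            for key, subdir in _DEFAULT_SUBDIRS.items():
                if not config.get(key, False):
                    config[key] = os.path.join(directory, subdir)
            return config


        assert apply_defaults('/work')['tmpdir'] == os.path.join('/work', 'tmp')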