[Notes] [Git][BuildGrid/buildgrid][mablanch/79-cas-downloader] 8 commits: client/cas.py: Introduce CAS downloader helper class



Title: GitLab

Martin Blanchard pushed to branch mablanch/79-cas-downloader at BuildGrid / buildgrid

Commits:

8 changed files:

Changes:

  • buildgrid/_app/bots/buildbox.py
    ... ... @@ -19,32 +19,34 @@ import tempfile
    19 19
     
    
    20 20
     from google.protobuf import any_pb2
    
    21 21
     
    
    22
    -from buildgrid.settings import HASH_LENGTH
    
    23
    -from buildgrid.client.cas import upload
    
    24
    -from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2
    
    25
    -from buildgrid._protos.google.bytestream import bytestream_pb2_grpc
    
    22
    +from buildgrid.client.cas import download, upload
    
    26 23
     from buildgrid._exceptions import BotError
    
    27
    -from buildgrid.utils import read_file, write_file, parse_to_pb2_from_fetch
    
    24
    +from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2
    
    25
    +from buildgrid.settings import HASH_LENGTH
    
    26
    +from buildgrid.utils import read_file, write_file
    
    28 27
     
    
    29 28
     
    
    30 29
     def work_buildbox(context, lease):
    
    31 30
         """Executes a lease for a build action, using buildbox.
    
    32 31
         """
    
    33 32
     
    
    34
    -    stub_bytestream = bytestream_pb2_grpc.ByteStreamStub(context.cas_channel)
    
    35 33
         local_cas_directory = context.local_cas
    
    34
    +    # instance_name = context.parent
    
    36 35
         logger = context.logger
    
    37 36
     
    
    38 37
         action_digest = remote_execution_pb2.Digest()
    
    39 38
         lease.payload.Unpack(action_digest)
    
    40 39
     
    
    41
    -    action = parse_to_pb2_from_fetch(remote_execution_pb2.Action(),
    
    42
    -                                     stub_bytestream, action_digest)
    
    40
    +    with download(context.cas_channel) as downloader:
    
    41
    +        action = downloader.get_message(action_digest,
    
    42
    +                                        remote_execution_pb2.Action())
    
    43 43
     
    
    44
    -    command = parse_to_pb2_from_fetch(remote_execution_pb2.Command(),
    
    45
    -                                      stub_bytestream, action.command_digest)
    
    44
    +        assert action.command_digest.hash
    
    46 45
     
    
    47
    -    environment = dict()
    
    46
    +        command = downloader.get_message(action.command_digest,
    
    47
    +                                         remote_execution_pb2.Command())
    
    48
    +
    
    49
    +    environment = {}
    
    48 50
         for variable in command.environment_variables:
    
    49 51
             if variable.name not in ['PWD']:
    
    50 52
                 environment[variable.name] = variable.value
    
    ... ... @@ -116,10 +118,11 @@ def work_buildbox(context, lease):
    116 118
     
    
    117 119
                 # TODO: Have BuildBox helping us creating the Tree instance here
    
    118 120
                 # See https://gitlab.com/BuildStream/buildbox/issues/7 for details
    
    119
    -            output_tree = _cas_tree_maker(stub_bytestream, output_digest)
    
    121
    +            with download(context.cas_channel) as downloader:
    
    122
    +                output_tree = _cas_tree_maker(downloader, output_digest)
    
    120 123
     
    
    121
    -            with upload(context.cas_channel) as cas:
    
    122
    -                output_tree_digest = cas.put_message(output_tree)
    
    124
    +            with upload(context.cas_channel) as uploader:
    
    125
    +                output_tree_digest = uploader.put_message(output_tree)
    
    123 126
     
    
    124 127
                 output_directory = remote_execution_pb2.OutputDirectory()
    
    125 128
                 output_directory.tree_digest.CopyFrom(output_tree_digest)
    
    ... ... @@ -135,24 +138,28 @@ def work_buildbox(context, lease):
    135 138
         return lease
    
    136 139
     
    
    137 140
     
    
    138
    -def _cas_tree_maker(stub_bytestream, directory_digest):
    
    141
    +def _cas_tree_maker(cas, directory_digest):
    
    139 142
         # Generates and stores a Tree for a given Directory. This is very inefficient
    
    140 143
         # and only temporary. See https://gitlab.com/BuildStream/buildbox/issues/7.
    
    141 144
         output_tree = remote_execution_pb2.Tree()
    
    142 145
     
    
    143
    -    def list_directories(parent_directory):
    
    144
    -        directory_list = list()
    
    146
    +    def __cas_tree_maker(cas, parent_directory):
    
    147
    +        digests, directories = [], []
    
    145 148
             for directory_node in parent_directory.directories:
    
    146
    -            directory = parse_to_pb2_from_fetch(remote_execution_pb2.Directory(),
    
    147
    -                                                stub_bytestream, directory_node.digest)
    
    148
    -            directory_list.extend(list_directories(directory))
    
    149
    -            directory_list.append(directory)
    
    149
    +            directories.append(remote_execution_pb2.Directory())
    
    150
    +            digests.append(directory_node.digest)
    
    151
    +
    
    152
    +        cas.get_messages(digests, directories)
    
    153
    +
    
    154
    +        for directory in directories[:]:
    
    155
    +            directories.extend(__cas_tree_maker(cas, directory))
    
    156
    +
    
    157
    +        return directories
    
    150 158
     
    
    151
    -        return directory_list
    
    159
    +    root_directory = cas.get_message(directory_digest,
    
    160
    +                                     remote_execution_pb2.Directory())
    
    152 161
     
    
    153
    -    root_directory = parse_to_pb2_from_fetch(remote_execution_pb2.Directory(),
    
    154
    -                                             stub_bytestream, directory_digest)
    
    155
    -    output_tree.children.extend(list_directories(root_directory))
    
    162
    +    output_tree.children.extend(__cas_tree_maker(cas, root_directory))
    
    156 163
         output_tree.root.CopyFrom(root_directory)
    
    157 164
     
    
    158 165
         return output_tree

  • buildgrid/_app/bots/temp_directory.py
    ... ... @@ -19,10 +19,8 @@ import tempfile
    19 19
     
    
    20 20
     from google.protobuf import any_pb2
    
    21 21
     
    
    22
    -from buildgrid.client.cas import upload
    
    22
    +from buildgrid.client.cas import download, upload
    
    23 23
     from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2
    
    24
    -from buildgrid._protos.google.bytestream import bytestream_pb2_grpc
    
    25
    -from buildgrid.utils import write_fetch_directory, parse_to_pb2_from_fetch
    
    26 24
     from buildgrid.utils import output_file_maker, output_directory_maker
    
    27 25
     
    
    28 26
     
    
    ... ... @@ -30,29 +28,30 @@ def work_temp_directory(context, lease):
    30 28
         """Executes a lease for a build action, using host tools.
    
    31 29
         """
    
    32 30
     
    
    33
    -    stub_bytestream = bytestream_pb2_grpc.ByteStreamStub(context.cas_channel)
    
    34 31
         instance_name = context.parent
    
    35 32
         logger = context.logger
    
    36 33
     
    
    37 34
         action_digest = remote_execution_pb2.Digest()
    
    38 35
         lease.payload.Unpack(action_digest)
    
    39 36
     
    
    40
    -    action = parse_to_pb2_from_fetch(remote_execution_pb2.Action(),
    
    41
    -                                     stub_bytestream, action_digest, instance_name)
    
    42
    -
    
    43 37
         with tempfile.TemporaryDirectory() as temp_directory:
    
    44
    -        command = parse_to_pb2_from_fetch(remote_execution_pb2.Command(),
    
    45
    -                                          stub_bytestream, action.command_digest, instance_name)
    
    38
    +        with download(context.cas_channel, instance=instance_name) as downloader:
    
    39
    +            action = downloader.get_message(action_digest,
    
    40
    +                                            remote_execution_pb2.Action())
    
    41
    +
    
    42
    +            assert action.command_digest.hash
    
    43
    +
    
    44
    +            command = downloader.get_message(action.command_digest,
    
    45
    +                                             remote_execution_pb2.Command())
    
    46 46
     
    
    47
    -        write_fetch_directory(temp_directory, stub_bytestream,
    
    48
    -                              action.input_root_digest, instance_name)
    
    47
    +            downloader.download_directory(action.input_root_digest, temp_directory)
    
    49 48
     
    
    50 49
             environment = os.environ.copy()
    
    51 50
             for variable in command.environment_variables:
    
    52 51
                 if variable.name not in ['PATH', 'PWD']:
    
    53 52
                     environment[variable.name] = variable.value
    
    54 53
     
    
    55
    -        command_line = list()
    
    54
    +        command_line = []
    
    56 55
             for argument in command.arguments:
    
    57 56
                 command_line.append(argument.strip())
    
    58 57
     
    

  • buildgrid/_app/commands/cmd_execute.py
    ... ... @@ -20,7 +20,6 @@ Execute command
    20 20
     Request work to be executed and monitor status of jobs.
    
    21 21
     """
    
    22 22
     
    
    23
    -import errno
    
    24 23
     import logging
    
    25 24
     import os
    
    26 25
     import stat
    
    ... ... @@ -30,10 +29,9 @@ from urllib.parse import urlparse
    30 29
     import click
    
    31 30
     import grpc
    
    32 31
     
    
    33
    -from buildgrid.client.cas import upload
    
    32
    +from buildgrid.client.cas import download, upload
    
    34 33
     from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2, remote_execution_pb2_grpc
    
    35
    -from buildgrid._protos.google.bytestream import bytestream_pb2_grpc
    
    36
    -from buildgrid.utils import create_digest, write_fetch_blob
    
    34
    +from buildgrid.utils import create_digest
    
    37 35
     
    
    38 36
     from ..cli import pass_context
    
    39 37
     
    
    ... ... @@ -154,8 +152,6 @@ def run_command(context, input_root, commands, output_file, output_directory):
    154 152
                                                       skip_cache_lookup=True)
    
    155 153
         response = stub.Execute(request)
    
    156 154
     
    
    157
    -    stub = bytestream_pb2_grpc.ByteStreamStub(context.channel)
    
    158
    -
    
    159 155
         stream = None
    
    160 156
         for stream in response:
    
    161 157
             context.logger.info(stream)
    
    ... ... @@ -163,21 +159,16 @@ def run_command(context, input_root, commands, output_file, output_directory):
    163 159
         execute_response = remote_execution_pb2.ExecuteResponse()
    
    164 160
         stream.response.Unpack(execute_response)
    
    165 161
     
    
    166
    -    for output_file_response in execute_response.result.output_files:
    
    167
    -        path = os.path.join(output_directory, output_file_response.path)
    
    168
    -
    
    169
    -        if not os.path.exists(os.path.dirname(path)):
    
    162
    +    with download(context.channel, instance=context.instance_name) as downloader:
    
    170 163
     
    
    171
    -            try:
    
    172
    -                os.makedirs(os.path.dirname(path))
    
    164
    +        for output_file_response in execute_response.result.output_files:
    
    165
    +            path = os.path.join(output_directory, output_file_response.path)
    
    173 166
     
    
    174
    -            except OSError as exc:
    
    175
    -                if exc.errno != errno.EEXIST:
    
    176
    -                    raise
    
    167
    +            if not os.path.exists(os.path.dirname(path)):
    
    168
    +                os.makedirs(os.path.dirname(path), exist_ok=True)
    
    177 169
     
    
    178
    -        with open(path, 'wb+') as f:
    
    179
    -            write_fetch_blob(f, stub, output_file_response.digest, context.instance_name)
    
    170
    +            downloader.download_file(output_file_response.digest, path)
    
    180 171
     
    
    181
    -        if output_file_response.path in output_executeables:
    
    182
    -            st = os.stat(path)
    
    183
    -            os.chmod(path, st.st_mode | stat.S_IXUSR)
    172
    +            if output_file_response.path in output_executeables:
    
    173
    +                st = os.stat(path)
    
    174
    +                os.chmod(path, st.st_mode | stat.S_IXUSR)

  • buildgrid/client/cas.py
    ... ... @@ -19,6 +19,7 @@ import os
    19 19
     
    
    20 20
     import grpc
    
    21 21
     
    
    22
    +from buildgrid._exceptions import NotFoundError
    
    22 23
     from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2, remote_execution_pb2_grpc
    
    23 24
     from buildgrid._protos.google.bytestream import bytestream_pb2, bytestream_pb2_grpc
    
    24 25
     from buildgrid._protos.google.rpc import code_pb2
    
    ... ... @@ -26,6 +27,16 @@ from buildgrid.settings import HASH
    26 27
     from buildgrid.utils import merkle_tree_maker
    
    27 28
     
    
    28 29
     
    
    30
    +# Maximum size for a queueable file:
    
    31
    +FILE_SIZE_THRESHOLD = 1 * 1024 * 1024
    
    32
    +
    
    33
    +# Maximum size for a single gRPC request:
    
    34
    +MAX_REQUEST_SIZE = 2 * 1024 * 1024
    
    35
    +
    
    36
    +# Maximum number of elements per gRPC request:
    
    37
    +MAX_REQUEST_COUNT = 500
    
    38
    +
    
    39
    +
    
    29 40
     class _CallCache:
    
    30 41
         """Per remote grpc.StatusCode.UNIMPLEMENTED call cache."""
    
    31 42
         __calls = {}
    
    ... ... @@ -43,6 +54,401 @@ class _CallCache:
    43 54
             return name in cls.__calls[channel]
    
    44 55
     
    
    45 56
     
    
    57
    +@contextmanager
    
    58
    +def download(channel, instance=None, u_uid=None):
    
    59
    +    """Context manager generator for the :class:`Downloader` class."""
    
    60
    +    downloader = Downloader(channel, instance=instance)
    
    61
    +    try:
    
    62
    +        yield downloader
    
    63
    +    finally:
    
    64
    +        downloader.close()
    
    65
    +
    
    66
    +
    
    67
    +class Downloader:
    
    68
    +    """Remote CAS files, directories and messages download helper.
    
    69
    +
    
    70
    +    The :class:`Downloader` class comes with a generator factory function that
    
    71
    +    can be used together with the `with` statement for context management::
    
    72
    +
    
    73
    +        from buildgrid.client.cas import download
    
    74
    +
    
    75
    +        with download(channel, instance='build') as downloader:
    
    76
    +            downloader.get_message(message_digest, remote_execution_pb2.Action())
    
    77
    +    """
    
    78
    +
    
    79
    +    def __init__(self, channel, instance=None):
    
    80
    +        """Initializes a new :class:`Downloader` instance.
    
    81
    +
    
    82
    +        Args:
    
    83
    +            channel (grpc.Channel): A gRPC channel to the CAS endpoint.
    
    84
    +            instance (str, optional): the targeted instance's name.
    
    85
    +        """
    
    86
    +        self.channel = channel
    
    87
    +
    
    88
    +        self.instance_name = instance
    
    89
    +
    
    90
    +        self.__bytestream_stub = bytestream_pb2_grpc.ByteStreamStub(self.channel)
    
    91
    +        self.__cas_stub = remote_execution_pb2_grpc.ContentAddressableStorageStub(self.channel)
    
    92
    +
    
    93
    +        self.__file_requests = {}
    
    94
    +        self.__file_request_count = 0
    
    95
    +        self.__file_request_size = 0
    
    96
    +        self.__file_response_size = 0
    
    97
    +
    
    98
    +    # --- Public API ---
    
    99
    +
    
    100
    +    def get_blob(self, digest):
    
    101
    +        """Retrieves a blob from the remote CAS server.
    
    102
    +
    
    103
    +        Args:
    
    104
    +            digest (:obj:`Digest`): the blob's digest to fetch.
    
    105
    +
    
    106
    +        Returns:
    
    107
    +            bytearray: the fetched blob data or None if not found.
    
    108
    +        """
    
    109
    +        try:
    
    110
    +            blob = self._fetch_blob(digest)
    
    111
    +        except NotFoundError:
    
    112
    +            return None
    
    113
    +
    
    114
    +        return blob
    
    115
    +
    
    116
    +    def get_blobs(self, digests):
    
    117
    +        """Retrieves a list of blobs from the remote CAS server.
    
    118
    +
    
    119
    +        Args:
    
    120
    +            digests (list): list of :obj:`Digest`s for the blobs to fetch.
    
    121
    +
    
    122
    +        Returns:
    
    123
    +            list: the fetched blob data list.
    
    124
    +        """
    
    125
    +        return self._fetch_blob_batch(digests)
    
    126
    +
    
    127
    +    def get_message(self, digest, message):
    
    128
    +        """Retrieves a :obj:`Message` from the remote CAS server.
    
    129
    +
    
    130
    +        Args:
    
    131
    +            digest (:obj:`Digest`): the message's digest to fetch.
    
    132
    +            message (:obj:`Message`): an empty message to fill.
    
    133
    +
    
    134
    +        Returns:
    
    135
    +            :obj:`Message`: `message` filled or emptied if not found.
    
    136
    +        """
    
    137
    +        try:
    
    138
    +            message_blob = self._fetch_blob(digest)
    
    139
    +        except NotFoundError:
    
    140
    +            message_blob = None
    
    141
    +
    
    142
    +        if message_blob is not None:
    
    143
    +            message.ParseFromString(message_blob)
    
    144
    +        else:
    
    145
    +            message.Clear()
    
    146
    +
    
    147
    +        return message
    
    148
    +
    
    149
    +    def get_messages(self, digests, messages):
    
    150
    +        """Retrieves a list of :obj:`Message`s from the remote CAS server.
    
    151
    +
    
    152
    +        Note:
    
    153
    +            The `digests` and `messages` list **must** contain the same number
    
    154
    +            of elements.
    
    155
    +
    
    156
    +        Args:
    
    157
    +            digests (list):  list of :obj:`Digest`s for the messages to fetch.
    
    158
    +            messages (list): list of empty :obj:`Message`s to fill.
    
    159
    +
    
    160
    +        Returns:
    
    161
    +            list: the fetched and filled message list.
    
    162
    +        """
    
    163
    +        assert len(digests) == len(messages)
    
    164
    +
    
    165
    +        message_blobs = self._fetch_blob_batch(digests)
    
    166
    +
    
    167
    +        assert len(message_blobs) == len(messages)
    
    168
    +
    
    169
    +        for message, message_blob in zip(messages, message_blobs):
    
    170
    +            message.ParseFromString(message_blob)
    
    171
    +
    
    172
    +        return messages
    
    173
    +
    
    174
    +    def download_file(self, digest, file_path, queue=True):
    
    175
    +        """Retrieves a file from the remote CAS server.
    
    176
    +
    
    177
    +        If queuing is allowed (`queue=True`), the download request **may** be
    
    178
    +        deferred. An explicit call to :func:`~flush` can force the request to be
    
    179
    +        sent immediately (along with the rest of the queued batch).
    
    180
    +
    
    181
    +        Args:
    
    182
    +            digest (:obj:`Digest`): the file's digest to fetch.
    
    183
    +            file_path (str): absolute or relative path to the local file to write.
    
    184
    +            queue (bool, optional): whether or not the download request may be
    
    185
    +                queued and submitted as part of a batch download request. Defaults
    
    186
    +                to True.
    
    187
    +
    
    188
    +        Raises:
    
    189
    +            NotFoundError: if `digest` is not present in the remote CAS server.
    
    190
    +            OSError: if `file_path` cannot be created or written to.
    
    191
    +        """
    
    192
    +        if not os.path.isabs(file_path):
    
    193
    +            file_path = os.path.abspath(file_path)
    
    194
    +
    
    195
    +        if not queue or digest.size_bytes > FILE_SIZE_THRESHOLD:
    
    196
    +            self._fetch_file(digest, file_path)
    
    197
    +        else:
    
    198
    +            self._queue_file(digest, file_path)
    
    199
    +
    
    200
    +    def download_directory(self, digest, directory_path):
    
    201
    +        """Retrieves a :obj:`Directory` from the remote CAS server.
    
    202
    +
    
    203
    +        Args:
    
    204
    +            digest (:obj:`Digest`): the directory's digest to fetch.
    
    205
    +
    
    206
    +        Raises:
    
    207
    +            NotFoundError: if `digest` is not present in the remote CAS server.
    
    208
    +            FileExistsError: if `directory_path` already contains parts of the
    
    209
    +                fetched directory's content.
    
    210
    +        """
    
    211
    +        if not os.path.isabs(directory_path):
    
    212
    +            directory_path = os.path.abspath(directory_path)
    
    213
    +
    
    214
    +        # We want to start fresh here, the rest is very synchronous...
    
    215
    +        self.flush()
    
    216
    +
    
    217
    +        self._fetch_directory(digest, directory_path)
    
    218
    +
    
    219
    +    def flush(self):
    
    220
    +        """Ensures any queued request gets sent."""
    
    221
    +        if self.__file_requests:
    
    222
    +            self._fetch_file_batch(self.__file_requests)
    
    223
    +
    
    224
    +            self.__file_requests.clear()
    
    225
    +            self.__file_request_count = 0
    
    226
    +            self.__file_request_size = 0
    
    227
    +            self.__file_response_size = 0
    
    228
    +
    
    229
    +    def close(self):
    
    230
    +        """Closes the underlying connection stubs.
    
    231
    +
    
    232
    +        Note:
    
    233
    +            This will always send pending requests before closing connections,
    
    234
    +            if any.
    
    235
    +        """
    
    236
    +        self.flush()
    
    237
    +
    
    238
    +        self.__bytestream_stub = None
    
    239
    +        self.__cas_stub = None
    
    240
    +
    
    241
    +    # --- Private API ---
    
    242
    +
    
    243
    +    def _fetch_blob(self, digest):
    
    244
    +        """Fetches a blob using ByteStream.Read()"""
    
    245
    +        read_blob = bytearray()
    
    246
    +
    
    247
    +        if self.instance_name is not None:
    
    248
    +            resource_name = '/'.join([self.instance_name, 'blobs',
    
    249
    +                                      digest.hash, str(digest.size_bytes)])
    
    250
    +        else:
    
    251
    +            resource_name = '/'.join(['blobs', digest.hash, str(digest.size_bytes)])
    
    252
    +
    
    253
    +        read_request = bytestream_pb2.ReadRequest()
    
    254
    +        read_request.resource_name = resource_name
    
    255
    +        read_request.read_offset = 0
    
    256
    +
    
    257
    +        try:
    
    258
    +            # TODO: Handle connection loss/recovery
    
    259
    +            for read_response in self.__bytestream_stub.Read(read_request):
    
    260
    +                read_blob += read_response.data
    
    261
    +
    
    262
    +            assert len(read_blob) == digest.size_bytes
    
    263
    +
    
    264
    +        except grpc.RpcError as e:
    
    265
    +            status_code = e.code()
    
    266
    +            if status_code == grpc.StatusCode.NOT_FOUND:
    
    267
    +                raise NotFoundError("Requested data does not exist on the remote.")
    
    268
    +
    
    269
    +            else:
    
    270
    +                assert False
    
    271
    +
    
    272
    +        return read_blob
    
    273
    +
    
    274
    +    def _fetch_blob_batch(self, digests):
    
    275
    +        """Fetches blobs using ContentAddressableStorage.BatchReadBlobs()"""
    
    276
    +        batch_fetched = False
    
    277
    +        read_blobs = []
    
    278
    +
    
    279
    +        # First, try BatchReadBlobs(), if not already known not being implemented:
    
    280
    +        if not _CallCache.unimplemented(self.channel, 'BatchReadBlobs'):
    
    281
    +            batch_request = remote_execution_pb2.BatchReadBlobsRequest()
    
    282
    +            batch_request.digests.extend(digests)
    
    283
    +            if self.instance_name is not None:
    
    284
    +                batch_request.instance_name = self.instance_name
    
    285
    +
    
    286
    +            try:
    
    287
    +                batch_response = self.__cas_stub.BatchReadBlobs(batch_request)
    
    288
    +                for response in batch_response.responses:
    
    289
    +                    assert response.digest.hash in digests
    
    290
    +
    
    291
    +                    read_blobs.append(response.data)
    
    292
    +
    
    293
    +                    if response.status.code != code_pb2.OK:
    
    294
    +                        assert False
    
    295
    +
    
    296
    +                batch_fetched = True
    
    297
    +
    
    298
    +            except grpc.RpcError as e:
    
    299
    +                status_code = e.code()
    
    300
    +                if status_code == grpc.StatusCode.UNIMPLEMENTED:
    
    301
    +                    _CallCache.mark_unimplemented(self.channel, 'BatchReadBlobs')
    
    302
    +
    
    303
    +                elif status_code == grpc.StatusCode.INVALID_ARGUMENT:
    
    304
    +                    read_blobs.clear()
    
    305
    +                    batch_fetched = False
    
    306
    +
    
    307
    +                else:
    
    308
    +                    assert False
    
    309
    +
    
    310
    +        # Fallback to Read() if no BatchReadBlobs():
    
    311
    +        if not batch_fetched:
    
    312
    +            for digest in digests:
    
    313
    +                read_blobs.append(self._fetch_blob(digest))
    
    314
    +
    
    315
    +        return read_blobs
    
    316
    +
    
    317
    +    def _fetch_file(self, digest, file_path):
    
    318
    +        """Fetches a file using ByteStream.Read()"""
    
    319
    +        if self.instance_name is not None:
    
    320
    +            resource_name = '/'.join([self.instance_name, 'blobs',
    
    321
    +                                      digest.hash, str(digest.size_bytes)])
    
    322
    +        else:
    
    323
    +            resource_name = '/'.join(['blobs', digest.hash, str(digest.size_bytes)])
    
    324
    +
    
    325
    +        read_request = bytestream_pb2.ReadRequest()
    
    326
    +        read_request.resource_name = resource_name
    
    327
    +        read_request.read_offset = 0
    
    328
    +
    
    329
    +        os.makedirs(os.path.dirname(file_path), exist_ok=True)
    
    330
    +
    
    331
    +        with open(file_path, 'wb') as byte_file:
    
    332
    +            # TODO: Handle connection loss/recovery
    
    333
    +            for read_response in self.__bytestream_stub.Read(read_request):
    
    334
    +                byte_file.write(read_response.data)
    
    335
    +
    
    336
    +            assert byte_file.tell() == digest.size_bytes
    
    337
    +
    
    338
    +    def _queue_file(self, digest, file_path):
    
    339
    +        """Queues a file for later batch download"""
    
    340
    +        if self.__file_request_size + digest.ByteSize() > MAX_REQUEST_SIZE:
    
    341
    +            self.flush()
    
    342
    +        elif self.__file_response_size + digest.size_bytes > MAX_REQUEST_SIZE:
    
    343
    +            self.flush()
    
    344
    +        elif self.__file_request_count >= MAX_REQUEST_COUNT:
    
    345
    +            self.flush()
    
    346
    +
    
    347
    +        self.__file_requests[digest.hash] = (digest, file_path)
    
    348
    +        self.__file_request_count += 1
    
    349
    +        self.__file_request_size += digest.ByteSize()
    
    350
    +        self.__file_response_size += digest.size_bytes
    
    351
    +
    
    352
    +    def _fetch_file_batch(self, batch):
    
    353
    +        """Fetches queued files using ContentAddressableStorage.BatchReadBlobs()"""
    
    354
    +        batch_digests = [digest for digest, _ in batch.values()]
    
    355
    +        batch_blobs = self._fetch_blob_batch(batch_digests)
    
    356
    +
    
    357
    +        for (_, file_path), file_blob in zip(batch.values(), batch_blobs):
    
    358
    +            os.makedirs(os.path.dirname(file_path), exist_ok=True)
    
    359
    +
    
    360
    +            with open(file_path, 'wb') as byte_file:
    
    361
    +                byte_file.write(file_blob)
    
    362
    +
    
    363
    +    def _fetch_directory(self, digest, directory_path):
    
    364
    +        """Fetches a directory using ContentAddressableStorage.GetTree()"""
    
    365
    +        # Better fail early if the local root path cannot be created:
    
    366
    +        os.makedirs(directory_path, exist_ok=True)
    
    367
    +
    
    368
    +        directories = {}
    
    369
    +        directory_fetched = False
    
    370
    +        # First, try GetTree() if not known to be unimplemented yet:
    
    371
    +        if not _CallCache.unimplemented(self.channel, 'GetTree'):
    
    372
    +            tree_request = remote_execution_pb2.GetTreeRequest()
    
    373
    +            tree_request.root_digest.CopyFrom(digest)
    
    374
    +            tree_request.page_size = MAX_REQUEST_COUNT
    
    375
    +            if self.instance_name is not None:
    
    376
    +                tree_request.instance_name = self.instance_name
    
    377
    +
    
    378
    +            try:
    
    379
    +                for tree_response in self.__cas_stub.GetTree(tree_request):
    
    380
    +                    for directory in tree_response.directories:
    
    381
    +                        directory_blob = directory.SerializeToString()
    
    382
    +                        directory_hash = HASH(directory_blob).hexdigest()
    
    383
    +
    
    384
    +                        directories[directory_hash] = directory
    
    385
    +
    
    386
    +                assert digest.hash in directories
    
    387
    +
    
    388
    +                directory = directories[digest.hash]
    
    389
    +                self._write_directory(digest.hash, directory_path,
    
    390
    +                                      directories=directories, root_barrier=directory_path)
    
    391
    +
    
    392
    +                directory_fetched = True
    
    393
    +
    
    394
    +            except grpc.RpcError as e:
    
    395
    +                status_code = e.code()
    
    396
    +                if status_code == grpc.StatusCode.UNIMPLEMENTED:
    
    397
    +                    _CallCache.mark_unimplemented(self.channel, 'BatchUpdateBlobs')
    
    398
    +
    
    399
    +                elif status_code == grpc.StatusCode.NOT_FOUND:
    
    400
    +                    raise NotFoundError("Requested directory does not exist on the remote.")
    
    401
    +
    
    402
    +                else:
    
    403
    +                    assert False
    
    404
    +
    
    405
    +        # TODO: Try with BatchReadBlobs().
    
    406
    +
    
    407
    +        # Fallback to Read() if no GetTree():
    
    408
    +        if not directory_fetched:
    
    409
    +            directory = remote_execution_pb2.Directory()
    
    410
    +            directory.ParseFromString(self._fetch_blob(digest))
    
    411
    +
    
    412
    +            self._write_directory(directory, directory_path,
    
    413
    +                                  root_barrier=directory_path)
    
    414
    +
    
    415
    +    def _write_directory(self, root_directory, root_path, directories=None, root_barrier=None):
    
    416
    +        """Generates a local directory structure"""
    
    417
    +        for file_node in root_directory.files:
    
    418
    +            file_path = os.path.join(root_path, file_node.name)
    
    419
    +
    
    420
    +            self._queue_file(file_node.digest, file_path)
    
    421
    +
    
    422
    +        for directory_node in root_directory.directories:
    
    423
    +            directory_path = os.path.join(root_path, directory_node.name)
    
    424
    +            if directories and directory_node.digest.hash in directories:
    
    425
    +                directory = directories[directory_node.digest.hash]
    
    426
    +            else:
    
    427
    +                directory = remote_execution_pb2.Directory()
    
    428
    +                directory.ParseFromString(self._fetch_blob(directory_node.digest))
    
    429
    +
    
    430
    +            os.makedirs(directory_path, exist_ok=True)
    
    431
    +
    
    432
    +            self._write_directory(directory, directory_path,
    
    433
    +                                  directories=directories, root_barrier=root_barrier)
    
    434
    +
    
    435
    +        for symlink_node in root_directory.symlinks:
    
    436
    +            symlink_path = os.path.join(root_path, symlink_node.name)
    
    437
    +            if not os.path.isabs(symlink_node.target):
    
    438
    +                target_path = os.path.join(root_path, symlink_node.target)
    
    439
    +            else:
    
    440
    +                target_path = symlink_node.target
    
    441
    +            target_path = os.path.normpath(target_path)
    
    442
    +
    
    443
    +            # Do not create links pointing outside the barrier:
    
    444
    +            if root_barrier is not None:
    
    445
    +                common_path = os.path.commonprefix([root_barrier, target_path])
    
    446
    +                if not common_path.startswith(root_barrier):
    
    447
    +                    continue
    
    448
    +
    
    449
    +            os.symlink(symlink_path, target_path)
    
    450
    +
    
    451
    +
    
    46 452
     @contextmanager
    
    47 453
     def upload(channel, instance=None, u_uid=None):
    
    48 454
         """Context manager generator for the :class:`Uploader` class."""
    
    ... ... @@ -63,16 +469,8 @@ class Uploader:
    63 469
     
    
    64 470
             with upload(channel, instance='build') as uploader:
    
    65 471
                 uploader.upload_file('/path/to/local/file')
    
    66
    -
    
    67
    -    Attributes:
    
    68
    -        FILE_SIZE_THRESHOLD (int): maximum size for a queueable file.
    
    69
    -        MAX_REQUEST_SIZE (int): maximum size for a single gRPC request.
    
    70 472
         """
    
    71 473
     
    
    72
    -    FILE_SIZE_THRESHOLD = 1 * 1024 * 1024
    
    73
    -    MAX_REQUEST_SIZE = 2 * 1024 * 1024
    
    74
    -    MAX_REQUEST_COUNT = 500
    
    75
    -
    
    76 474
         def __init__(self, channel, instance=None, u_uid=None):
    
    77 475
             """Initializes a new :class:`Uploader` instance.
    
    78 476
     
    
    ... ... @@ -115,7 +513,7 @@ class Uploader:
    115 513
             Returns:
    
    116 514
                 :obj:`Digest`: the sent blob's digest.
    
    117 515
             """
    
    118
    -        if not queue or len(blob) > Uploader.FILE_SIZE_THRESHOLD:
    
    516
    +        if not queue or len(blob) > FILE_SIZE_THRESHOLD:
    
    119 517
                 blob_digest = self._send_blob(blob, digest=digest)
    
    120 518
             else:
    
    121 519
                 blob_digest = self._queue_blob(blob, digest=digest)
    
    ... ... @@ -141,7 +539,7 @@ class Uploader:
    141 539
             """
    
    142 540
             message_blob = message.SerializeToString()
    
    143 541
     
    
    144
    -        if not queue or len(message_blob) > Uploader.FILE_SIZE_THRESHOLD:
    
    542
    +        if not queue or len(message_blob) > FILE_SIZE_THRESHOLD:
    
    145 543
                 message_digest = self._send_blob(message_blob, digest=digest)
    
    146 544
             else:
    
    147 545
                 message_digest = self._queue_blob(message_blob, digest=digest)
    
    ... ... @@ -174,7 +572,7 @@ class Uploader:
    174 572
             with open(file_path, 'rb') as bytes_steam:
    
    175 573
                 file_bytes = bytes_steam.read()
    
    176 574
     
    
    177
    -        if not queue or len(file_bytes) > Uploader.FILE_SIZE_THRESHOLD:
    
    575
    +        if not queue or len(file_bytes) > FILE_SIZE_THRESHOLD:
    
    178 576
                 file_digest = self._send_blob(file_bytes)
    
    179 577
             else:
    
    180 578
                 file_digest = self._queue_blob(file_bytes)
    
    ... ... @@ -347,9 +745,9 @@ class Uploader:
    347 745
                 blob_digest.hash = HASH(blob).hexdigest()
    
    348 746
                 blob_digest.size_bytes = len(blob)
    
    349 747
     
    
    350
    -        if self.__request_size + blob_digest.size_bytes > Uploader.MAX_REQUEST_SIZE:
    
    748
    +        if self.__request_size + blob_digest.size_bytes > MAX_REQUEST_SIZE:
    
    351 749
                 self.flush()
    
    352
    -        elif self.__request_count >= Uploader.MAX_REQUEST_COUNT:
    
    750
    +        elif self.__request_count >= MAX_REQUEST_COUNT:
    
    353 751
                 self.flush()
    
    354 752
     
    
    355 753
             self.__requests[blob_digest.hash] = (blob, blob_digest)
    

  • buildgrid/server/cas/storage/remote.py
    ... ... @@ -23,14 +23,10 @@ Forwwards storage requests to a remote storage.
    23 23
     import io
    
    24 24
     import logging
    
    25 25
     
    
    26
    -import grpc
    
    27
    -
    
    28
    -from buildgrid.client.cas import upload
    
    29
    -from buildgrid._protos.google.bytestream import bytestream_pb2_grpc
    
    26
    +from buildgrid.client.cas import download, upload
    
    30 27
     from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2, remote_execution_pb2_grpc
    
    31 28
     from buildgrid._protos.google.rpc import code_pb2
    
    32 29
     from buildgrid._protos.google.rpc import status_pb2
    
    33
    -from buildgrid.utils import gen_fetch_blob
    
    34 30
     from buildgrid.settings import HASH
    
    35 31
     
    
    36 32
     from .storage_abc import StorageABC
    
    ... ... @@ -44,7 +40,6 @@ class RemoteStorage(StorageABC):
    44 40
             self.instance_name = instance_name
    
    45 41
             self.channel = channel
    
    46 42
     
    
    47
    -        self._stub_bs = bytestream_pb2_grpc.ByteStreamStub(channel)
    
    48 43
             self._stub_cas = remote_execution_pb2_grpc.ContentAddressableStorageStub(channel)
    
    49 44
     
    
    50 45
         def has_blob(self, digest):
    
    ... ... @@ -53,25 +48,12 @@ class RemoteStorage(StorageABC):
    53 48
             return False
    
    54 49
     
    
    55 50
         def get_blob(self, digest):
    
    56
    -        try:
    
    57
    -            fetched_data = io.BytesIO()
    
    58
    -            length = 0
    
    59
    -
    
    60
    -            for data in gen_fetch_blob(self._stub_bs, digest, self.instance_name):
    
    61
    -                length += fetched_data.write(data)
    
    62
    -
    
    63
    -            assert digest.size_bytes == length
    
    64
    -            fetched_data.seek(0)
    
    65
    -            return fetched_data
    
    66
    -
    
    67
    -        except grpc.RpcError as e:
    
    68
    -            if e.code() == grpc.StatusCode.NOT_FOUND:
    
    69
    -                pass
    
    51
    +        with download(self.channel, instance=self.instance_name) as downloader:
    
    52
    +            blob = downloader.get_blob(digest)
    
    53
    +            if blob is not None:
    
    54
    +                return io.BytesIO(blob)
    
    70 55
                 else:
    
    71
    -                self.logger.error(e.details())
    
    72
    -                raise
    
    73
    -
    
    74
    -        return None
    
    56
    +                return None
    
    75 57
     
    
    76 58
         def begin_write(self, digest):
    
    77 59
             return io.BytesIO()
    

  • buildgrid/utils.py
    ... ... @@ -18,87 +18,6 @@ import os
    18 18
     
    
    19 19
     from buildgrid.settings import HASH
    
    20 20
     from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2
    
    21
    -from buildgrid._protos.google.bytestream import bytestream_pb2
    
    22
    -
    
    23
    -
    
    24
    -def gen_fetch_blob(stub, digest, instance_name=""):
    
    25
    -    """ Generates byte stream from a fetch blob request
    
    26
    -    """
    
    27
    -
    
    28
    -    resource_name = os.path.join(instance_name, 'blobs', digest.hash, str(digest.size_bytes))
    
    29
    -    request = bytestream_pb2.ReadRequest(resource_name=resource_name,
    
    30
    -                                         read_offset=0)
    
    31
    -
    
    32
    -    for response in stub.Read(request):
    
    33
    -        yield response.data
    
    34
    -
    
    35
    -
    
    36
    -def write_fetch_directory(root_directory, stub, digest, instance_name=None):
    
    37
    -    """Locally replicates a directory from CAS.
    
    38
    -
    
    39
    -    Args:
    
    40
    -        root_directory (str): local directory to populate.
    
    41
    -        stub (): gRPC stub for CAS communication.
    
    42
    -        digest (Digest): digest for the directory to fetch from CAS.
    
    43
    -        instance_name (str, optional): farm instance name to query data from.
    
    44
    -    """
    
    45
    -
    
    46
    -    if not os.path.isabs(root_directory):
    
    47
    -        root_directory = os.path.abspath(root_directory)
    
    48
    -    if not os.path.exists(root_directory):
    
    49
    -        os.makedirs(root_directory, exist_ok=True)
    
    50
    -
    
    51
    -    directory = parse_to_pb2_from_fetch(remote_execution_pb2.Directory(),
    
    52
    -                                        stub, digest, instance_name)
    
    53
    -
    
    54
    -    for directory_node in directory.directories:
    
    55
    -        child_path = os.path.join(root_directory, directory_node.name)
    
    56
    -
    
    57
    -        write_fetch_directory(child_path, stub, directory_node.digest, instance_name)
    
    58
    -
    
    59
    -    for file_node in directory.files:
    
    60
    -        child_path = os.path.join(root_directory, file_node.name)
    
    61
    -
    
    62
    -        with open(child_path, 'wb') as child_file:
    
    63
    -            write_fetch_blob(child_file, stub, file_node.digest, instance_name)
    
    64
    -
    
    65
    -    for symlink_node in directory.symlinks:
    
    66
    -        child_path = os.path.join(root_directory, symlink_node.name)
    
    67
    -
    
    68
    -        if os.path.isabs(symlink_node.target):
    
    69
    -            continue  # No out of temp-directory links for now.
    
    70
    -        target_path = os.path.join(root_directory, symlink_node.target)
    
    71
    -
    
    72
    -        os.symlink(child_path, target_path)
    
    73
    -
    
    74
    -
    
    75
    -def write_fetch_blob(target_file, stub, digest, instance_name=None):
    
    76
    -    """Extracts a blob from CAS into a local file.
    
    77
    -
    
    78
    -    Args:
    
    79
    -        target_file (str): local file to write.
    
    80
    -        stub (): gRPC stub for CAS communication.
    
    81
    -        digest (Digest): digest for the blob to fetch from CAS.
    
    82
    -        instance_name (str, optional): farm instance name to query data from.
    
    83
    -    """
    
    84
    -
    
    85
    -    for stream in gen_fetch_blob(stub, digest, instance_name):
    
    86
    -        target_file.write(stream)
    
    87
    -    target_file.flush()
    
    88
    -
    
    89
    -    assert digest.size_bytes == os.fstat(target_file.fileno()).st_size
    
    90
    -
    
    91
    -
    
    92
    -def parse_to_pb2_from_fetch(pb2, stub, digest, instance_name=""):
    
    93
    -    """ Fetches stream and parses it into given pb2
    
    94
    -    """
    
    95
    -
    
    96
    -    stream_bytes = b''
    
    97
    -    for stream in gen_fetch_blob(stub, digest, instance_name):
    
    98
    -        stream_bytes += stream
    
    99
    -
    
    100
    -    pb2.ParseFromString(stream_bytes)
    
    101
    -    return pb2
    
    102 21
     
    
    103 22
     
    
    104 23
     def create_digest(bytes_to_digest):
    

  • tests/cas/test_client.py
    ... ... @@ -14,12 +14,15 @@
    14 14
     
    
    15 15
     # pylint: disable=redefined-outer-name
    
    16 16
     
    
    17
    +
    
    18
    +from copy import deepcopy
    
    17 19
     import os
    
    20
    +import tempfile
    
    18 21
     
    
    19 22
     import grpc
    
    20 23
     import pytest
    
    21 24
     
    
    22
    -from buildgrid.client.cas import upload
    
    25
    +from buildgrid.client.cas import download, upload
    
    23 26
     from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2
    
    24 27
     from buildgrid.utils import create_digest
    
    25 28
     
    
    ... ... @@ -41,6 +44,8 @@ FILES = [
    41 44
         (os.path.join(DATA_DIR, 'hello.cc'),),
    
    42 45
         (os.path.join(DATA_DIR, 'hello', 'hello.c'),
    
    43 46
          os.path.join(DATA_DIR, 'hello', 'hello.h'))]
    
    47
    +FOLDERS = [
    
    48
    +    (os.path.join(DATA_DIR, 'hello'),)]
    
    44 49
     DIRECTORIES = [
    
    45 50
         (os.path.join(DATA_DIR, 'hello'),),
    
    46 51
         (os.path.join(DATA_DIR, 'hello'), DATA_DIR)]
    
    ... ... @@ -214,3 +219,145 @@ def test_upload_tree(instance, directory_paths):
    214 219
                 directory_digest = create_digest(tree.root.SerializeToString())
    
    215 220
     
    
    216 221
                 assert server.compare_directories(directory_digest, directory_path)
    
    222
    +
    
    223
    +
    
    224
    +@pytest.mark.parametrize('blobs', BLOBS)
    
    225
    +@pytest.mark.parametrize('instance', INTANCES)
    
    226
    +def test_download_blob(instance, blobs):
    
    227
    +    # Actual test function, to be run in a subprocess:
    
    228
    +    def __test_download_blob(queue, remote, instance, digests):
    
    229
    +        # Open a channel to the remote CAS server:
    
    230
    +        channel = grpc.insecure_channel(remote)
    
    231
    +
    
    232
    +        blobs = []
    
    233
    +        with download(channel, instance) as downloader:
    
    234
    +            if len(digests) > 1:
    
    235
    +                blobs.extend(downloader.get_blobs(digests))
    
    236
    +            else:
    
    237
    +                blobs.append(downloader.get_blob(digests[0]))
    
    238
    +
    
    239
    +        queue.put(blobs)
    
    240
    +
    
    241
    +    # Start a minimal CAS server in a subprocess:
    
    242
    +    with serve_cas([instance]) as server:
    
    243
    +        digests = []
    
    244
    +        for blob in blobs:
    
    245
    +            digest = server.store_blob(blob)
    
    246
    +            digests.append(digest)
    
    247
    +
    
    248
    +        blobs = run_in_subprocess(__test_download_blob,
    
    249
    +                                  server.remote, instance, digests)
    
    250
    +
    
    251
    +        for digest, blob in zip(digests, blobs):
    
    252
    +            assert server.compare_blobs(digest, blob)
    
    253
    +
    
    254
    +
    
    255
    +@pytest.mark.parametrize('messages', MESSAGES)
    
    256
    +@pytest.mark.parametrize('instance', INTANCES)
    
    257
    +def test_download_message(instance, messages):
    
    258
    +    # Actual test function, to be run in a subprocess:
    
    259
    +    def __test_download_message(queue, remote, instance, digests, empty_messages):
    
    260
    +        # Open a channel to the remote CAS server:
    
    261
    +        channel = grpc.insecure_channel(remote)
    
    262
    +
    
    263
    +        messages = []
    
    264
    +        with download(channel, instance) as downloader:
    
    265
    +            if len(digests) > 1:
    
    266
    +                messages = downloader.get_messages(digests, empty_messages)
    
    267
    +                messages = list([m.SerializeToString() for m in messages])
    
    268
    +            else:
    
    269
    +                message = downloader.get_message(digests[0], empty_messages[0])
    
    270
    +                messages.append(message.SerializeToString())
    
    271
    +
    
    272
    +        queue.put(messages)
    
    273
    +
    
    274
    +    # Start a minimal CAS server in a subprocess:
    
    275
    +    with serve_cas([instance]) as server:
    
    276
    +        empty_messages, digests = [], []
    
    277
    +        for message in messages:
    
    278
    +            digest = server.store_message(message)
    
    279
    +            digests.append(digest)
    
    280
    +
    
    281
    +            empty_message = deepcopy(message)
    
    282
    +            empty_message.Clear()
    
    283
    +            empty_messages.append(empty_message)
    
    284
    +
    
    285
    +        messages = run_in_subprocess(__test_download_message,
    
    286
    +                                     server.remote, instance, digests, empty_messages)
    
    287
    +
    
    288
    +        for digest, message_blob, message in zip(digests, messages, empty_messages):
    
    289
    +            message.ParseFromString(message_blob)
    
    290
    +
    
    291
    +            assert server.compare_messages(digest, message)
    
    292
    +
    
    293
    +
    
    294
    +@pytest.mark.parametrize('file_paths', FILES)
    
    295
    +@pytest.mark.parametrize('instance', INTANCES)
    
    296
    +def test_download_file(instance, file_paths):
    
    297
    +    # Actual test function, to be run in a subprocess:
    
    298
    +    def __test_download_file(queue, remote, instance, digests, paths):
    
    299
    +        # Open a channel to the remote CAS server:
    
    300
    +        channel = grpc.insecure_channel(remote)
    
    301
    +
    
    302
    +        with download(channel, instance) as downloader:
    
    303
    +            if len(digests) > 1:
    
    304
    +                for digest, path in zip(digests, paths):
    
    305
    +                    downloader.download_file(digest, path, queue=False)
    
    306
    +            else:
    
    307
    +                downloader.download_file(digests[0], paths[0], queue=False)
    
    308
    +
    
    309
    +        queue.put(None)
    
    310
    +
    
    311
    +    # Start a minimal CAS server in a subprocess:
    
    312
    +    with serve_cas([instance]) as server:
    
    313
    +        with tempfile.TemporaryDirectory() as temp_folder:
    
    314
    +            paths, digests = [], []
    
    315
    +            for file_path in file_paths:
    
    316
    +                digest = server.store_file(file_path)
    
    317
    +                digests.append(digest)
    
    318
    +
    
    319
    +                path = os.path.relpath(file_path, start=DATA_DIR)
    
    320
    +                path = os.path.join(temp_folder, path)
    
    321
    +                paths.append(path)
    
    322
    +
    
    323
    +                run_in_subprocess(__test_download_file,
    
    324
    +                                  server.remote, instance, digests, paths)
    
    325
    +
    
    326
    +            for digest, path in zip(digests, paths):
    
    327
    +                assert server.compare_files(digest, path)
    
    328
    +
    
    329
    +
    
    330
    +@pytest.mark.parametrize('folder_paths', FOLDERS)
    
    331
    +@pytest.mark.parametrize('instance', INTANCES)
    
    332
    +def test_download_directory(instance, folder_paths):
    
    333
    +    # Actual test function, to be run in a subprocess:
    
    334
    +    def __test_download_directory(queue, remote, instance, digests, paths):
    
    335
    +        # Open a channel to the remote CAS server:
    
    336
    +        channel = grpc.insecure_channel(remote)
    
    337
    +
    
    338
    +        with download(channel, instance) as downloader:
    
    339
    +            if len(digests) > 1:
    
    340
    +                for digest, path in zip(digests, paths):
    
    341
    +                    downloader.download_directory(digest, path)
    
    342
    +            else:
    
    343
    +                downloader.download_directory(digests[0], paths[0])
    
    344
    +
    
    345
    +        queue.put(None)
    
    346
    +
    
    347
    +    # Start a minimal CAS server in a subprocess:
    
    348
    +    with serve_cas([instance]) as server:
    
    349
    +        with tempfile.TemporaryDirectory() as temp_folder:
    
    350
    +            paths, digests = [], []
    
    351
    +            for folder_path in folder_paths:
    
    352
    +                digest = server.store_folder(folder_path)
    
    353
    +                digests.append(digest)
    
    354
    +
    
    355
    +                path = os.path.relpath(folder_path, start=DATA_DIR)
    
    356
    +                path = os.path.join(temp_folder, path)
    
    357
    +                paths.append(path)
    
    358
    +
    
    359
    +                run_in_subprocess(__test_download_directory,
    
    360
    +                                  server.remote, instance, digests, paths)
    
    361
    +
    
    362
    +            for digest, path in zip(digests, paths):
    
    363
    +                assert server.compare_directories(digest, path)

  • tests/utils/cas.py
    ... ... @@ -30,6 +30,7 @@ from buildgrid.server.cas.service import ContentAddressableStorageService
    30 30
     from buildgrid.server.cas.instance import ByteStreamInstance
    
    31 31
     from buildgrid.server.cas.instance import ContentAddressableStorageInstance
    
    32 32
     from buildgrid.server.cas.storage.disk import DiskStorage
    
    33
    +from buildgrid.utils import create_digest, merkle_tree_maker
    
    33 34
     
    
    34 35
     
    
    35 36
     @contextmanager
    
    ... ... @@ -124,6 +125,15 @@ class Server:
    124 125
         def get(self, digest):
    
    125 126
             return self.__storage.get_blob(digest).read()
    
    126 127
     
    
    128
    +    def store_blob(self, blob):
    
    129
    +        digest = create_digest(blob)
    
    130
    +        write_buffer = self.__storage.begin_write(digest)
    
    131
    +        write_buffer.write(blob)
    
    132
    +
    
    133
    +        self.__storage.commit_write(digest, write_buffer)
    
    134
    +
    
    135
    +        return digest
    
    136
    +
    
    127 137
         def compare_blobs(self, digest, blob):
    
    128 138
             if not self.__storage.has_blob(digest):
    
    129 139
                 return False
    
    ... ... @@ -133,6 +143,16 @@ class Server:
    133 143
     
    
    134 144
             return blob == stored_blob
    
    135 145
     
    
    146
    +    def store_message(self, message):
    
    147
    +        message_blob = message.SerializeToString()
    
    148
    +        message_digest = create_digest(message_blob)
    
    149
    +        write_buffer = self.__storage.begin_write(message_digest)
    
    150
    +        write_buffer.write(message_blob)
    
    151
    +
    
    152
    +        self.__storage.commit_write(message_digest, write_buffer)
    
    153
    +
    
    154
    +        return message_digest
    
    155
    +
    
    136 156
         def compare_messages(self, digest, message):
    
    137 157
             if not self.__storage.has_blob(digest):
    
    138 158
                 return False
    
    ... ... @@ -144,6 +164,17 @@ class Server:
    144 164
     
    
    145 165
             return message_blob == stored_blob
    
    146 166
     
    
    167
    +    def store_file(self, file_path):
    
    168
    +        with open(file_path, 'rb') as file_bytes:
    
    169
    +            file_blob = file_bytes.read()
    
    170
    +        file_digest = create_digest(file_blob)
    
    171
    +        write_buffer = self.__storage.begin_write(file_digest)
    
    172
    +        write_buffer.write(file_blob)
    
    173
    +
    
    174
    +        self.__storage.commit_write(file_digest, write_buffer)
    
    175
    +
    
    176
    +        return file_digest
    
    177
    +
    
    147 178
         def compare_files(self, digest, file_path):
    
    148 179
             if not self.__storage.has_blob(digest):
    
    149 180
                 return False
    
    ... ... @@ -156,6 +187,17 @@ class Server:
    156 187
     
    
    157 188
             return file_blob == stored_blob
    
    158 189
     
    
    190
    +    def store_folder(self, folder_path):
    
    191
    +        last_digest = None
    
    192
    +        for node, blob, _ in merkle_tree_maker(folder_path):
    
    193
    +            write_buffer = self.__storage.begin_write(node.digest)
    
    194
    +            write_buffer.write(blob)
    
    195
    +
    
    196
    +            self.__storage.commit_write(node.digest, write_buffer)
    
    197
    +            last_digest = node.digest
    
    198
    +
    
    199
    +        return last_digest
    
    200
    +
    
    159 201
         def compare_directories(self, digest, directory_path):
    
    160 202
             if not self.__storage.has_blob(digest):
    
    161 203
                 return False
    



  • [Date Prev][Date Next]   [Thread Prev][Thread Next]   [Thread Index] [Date Index] [Author Index]