[Notes] [Git][BuildGrid/buildgrid][master] 13 commits: client/cas.py: Rename the message uploading helper



Martin Blanchard pushed to branch master at BuildGrid / buildgrid

Commits:

17 changed files:

Changes:

  • buildgrid/_app/bots/buildbox.py
    @@ -119,7 +119,7 @@ def work_buildbox(context, lease):
                 output_tree = _cas_tree_maker(stub_bytestream, output_digest)

                 with upload(context.cas_channel) as cas:
    -                output_tree_digest = cas.send_message(output_tree)
    +                output_tree_digest = cas.put_message(output_tree)

                 output_directory = remote_execution_pb2.OutputDirectory()
                 output_directory.tree_digest.CopyFrom(output_tree_digest)

  • buildgrid/_app/bots/temp_directory.py
    @@ -94,15 +94,21 @@ def work_temp_directory(context, lease):
             logger.debug("Command stdout: [{}]".format(stdout))
             logger.debug("Command exit code: [{}]".format(returncode))

    -        with upload(context.cas_channel, instance=instance_name) as cas:
    +        with upload(context.cas_channel, instance=instance_name) as uploader:
    +            output_files, output_directories = [], []
    +
                 for output_path in command.output_files:
                     file_path = os.path.join(working_directory, output_path)
                     # Missing outputs should simply be omitted in ActionResult:
                     if not os.path.isfile(file_path):
                         continue

    -                output_file = output_file_maker(file_path, working_directory, cas=cas)
    -                action_result.output_files.extend([output_file])
    +                file_digest = uploader.upload_file(file_path, queue=True)
    +                output_file = output_file_maker(file_path, working_directory,
    +                                                file_digest)
    +                output_files.append(output_file)
    +
    +            action_result.output_files.extend(output_files)

                 for output_path in command.output_directories:
                     directory_path = os.path.join(working_directory, output_path)

    @@ -110,10 +116,12 @@ def work_temp_directory(context, lease):
                     if not os.path.isdir(directory_path):
                         continue

    -                # OutputDirectory.path should be relative to the working direcory:
    -                output_directory = output_directory_maker(directory_path, working_directory, cas=cas)
    +                tree_digest = uploader.upload_tree(directory_path, queue=True)
    +                output_directory = output_directory_maker(directory_path, working_directory,
    +                                                          tree_digest)
    +                output_directories.append(output_directory)

    -                action_result.output_directories.extend([output_directory])
    +            action_result.output_directories.extend(output_directories)

             action_result_any = any_pb2.Any()
             action_result_any.Pack(action_result)
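
    Taken together, these hunks replace the implicit cas=cas plumbing with
    explicit digests obtained from the uploader. A minimal sketch of the
    resulting worker pattern; channel, instance_name, command, action_result
    and working_directory are hypothetical placeholders, not values from this
    commit:

        import os

        from buildgrid.client.cas import upload

        with upload(channel, instance=instance_name) as uploader:
            output_files = []
            for output_path in command.output_files:
                file_path = os.path.join(working_directory, output_path)
                if not os.path.isfile(file_path):
                    continue  # missing outputs are omitted from ActionResult
                # The digest is computed locally right away; the actual
                # transfer may stay queued until the uploader flushes:
                file_digest = uploader.upload_file(file_path, queue=True)
                output_files.append(
                    output_file_maker(file_path, working_directory, file_digest))
            action_result.output_files.extend(output_files)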
    

  • buildgrid/_app/commands/cmd_cas.py
    @@ -21,14 +21,16 @@ Request work to be executed and monitor status of jobs.
     """

     import logging
    +import os
     import sys
     from urllib.parse import urlparse

     import click
     import grpc

    -from buildgrid.utils import merkle_maker, create_digest
    -from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2, remote_execution_pb2_grpc
    +from buildgrid.client.cas import upload
    +from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2
    +from buildgrid.utils import merkle_tree_maker

     from ..cli import pass_context

    @@ -68,56 +70,62 @@ def cli(context, remote, instance_name, client_key, client_cert, server_cert):
     @cli.command('upload-dummy', short_help="Upload a dummy action. Should be used with `execute dummy-request`")
     @pass_context
     def upload_dummy(context):
    -    context.logger.info("Uploading dummy action...")
         action = remote_execution_pb2.Action(do_not_cache=True)
    -    action_digest = create_digest(action.SerializeToString())
    +    with upload(context.channel, instance=context.instance_name) as uploader:
    +        action_digest = uploader.put_message(action)

    -    request = remote_execution_pb2.BatchUpdateBlobsRequest(instance_name=context.instance_name)
    -    request.requests.add(digest=action_digest,
    -                         data=action.SerializeToString())
    -
    -    stub = remote_execution_pb2_grpc.ContentAddressableStorageStub(context.channel)
    -    response = stub.BatchUpdateBlobs(request)
    -
    -    context.logger.info(response)
    +    if action_digest.ByteSize():
    +        click.echo('Success: Pushed digest "{}/{}"'
    +                   .format(action_digest.hash, action_digest.size_bytes))
    +    else:
    +        click.echo("Error: Failed pushing empty message.", err=True)


     @cli.command('upload-files', short_help="Upload files to the CAS server.")
    -@click.argument('files', nargs=-1, type=click.File('rb'), required=True)
    +@click.argument('files', nargs=-1, type=click.Path(exists=True, dir_okay=False), required=True)
     @pass_context
     def upload_files(context, files):
    -    stub = remote_execution_pb2_grpc.ContentAddressableStorageStub(context.channel)
    +    sent_digests, files_map = [], {}
    +    with upload(context.channel, instance=context.instance_name) as uploader:
    +        for file_path in files:
    +            context.logger.debug("Queueing {}".format(file_path))

    -    requests = []
    -    for file in files:
    -        chunk = file.read()
    -        requests.append(remote_execution_pb2.BatchUpdateBlobsRequest.Request(
    -            digest=create_digest(chunk), data=chunk))
    +            file_digest = uploader.upload_file(file_path, queue=True)

    -    request = remote_execution_pb2.BatchUpdateBlobsRequest(instance_name=context.instance_name,
    -                                                           requests=requests)
    +            files_map[file_digest.hash] = file_path
    +            sent_digests.append(file_digest)

    -    context.logger.info("Sending: {}".format(request))
    -    response = stub.BatchUpdateBlobs(request)
    -    context.logger.info("Response: {}".format(response))
    +    for file_digest in sent_digests:
    +        file_path = files_map[file_digest.hash]
    +        if os.path.isabs(file_path):
    +            file_path = os.path.relpath(file_path)
    +        if file_digest.ByteSize():
    +            click.echo('Success: Pushed "{}" with digest "{}/{}"'
    +                       .format(file_path, file_digest.hash, file_digest.size_bytes))
    +        else:
    +            click.echo('Error: Failed to push "{}"'.format(file_path), err=True)


     @cli.command('upload-dir', short_help="Upload a directory to the CAS server.")
    -@click.argument('directory', nargs=1, type=click.Path(), required=True)
    +@click.argument('directory', nargs=1, type=click.Path(exists=True, file_okay=False), required=True)
     @pass_context
     def upload_dir(context, directory):
    -    context.logger.info("Uploading directory to cas")
    -    stub = remote_execution_pb2_grpc.ContentAddressableStorageStub(context.channel)
    -
    -    requests = []
    -
    -    for chunk, file_digest in merkle_maker(directory):
    -        requests.append(remote_execution_pb2.BatchUpdateBlobsRequest.Request(
    -            digest=file_digest, data=chunk))
    -
    -    request = remote_execution_pb2.BatchUpdateBlobsRequest(instance_name=context.instance_name,
    -                                                           requests=requests)
    -
    -    context.logger.info("Request:\n{}".format(request))
    -    response = stub.BatchUpdateBlobs(request)
    -    context.logger.info("Response:\n{}".format(response))
    +    sent_digests, nodes_map = [], {}
    +    with upload(context.channel, instance=context.instance_name) as uploader:
    +        for node, blob, path in merkle_tree_maker(directory):
    +            context.logger.debug("Queueing {}".format(path))
    +
    +            node_digest = uploader.put_blob(blob, digest=node.digest, queue=True)
    +
    +            nodes_map[node.digest.hash] = path
    +            sent_digests.append(node_digest)
    +
    +    for node_digest in sent_digests:
    +        node_path = nodes_map[node_digest.hash]
    +        if os.path.isabs(node_path):
    +            node_path = os.path.relpath(node_path, start=directory)
    +        if node_digest.ByteSize():
    +            click.echo('Success: Pushed "{}" with digest "{}/{}"'
    +                       .format(node_path, node_digest.hash, node_digest.size_bytes))
    +        else:
    +            click.echo('Error: Failed to push "{}"'.format(node_path), err=True)
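
    All three commands now share one pattern: queue everything on an Uploader,
    let the context manager flush on exit, then report per-blob status. A
    condensed sketch of that pattern; channel and instance_name stand in for
    the CLI context and the file paths are hypothetical:

        import os

        from buildgrid.client.cas import upload

        sent_digests, files_map = [], {}
        with upload(channel, instance=instance_name) as uploader:
            for file_path in ['input/a.txt', 'input/b.txt']:  # hypothetical inputs
                file_digest = uploader.upload_file(file_path, queue=True)
                files_map[file_digest.hash] = file_path
                sent_digests.append(file_digest)

        # Leaving the with-block flushes the queued batch; the CLI treats an
        # empty digest (ByteSize() == 0) as a failed push:
        for file_digest in sent_digests:
            file_path = os.path.relpath(files_map[file_digest.hash])
            print('OK' if file_digest.ByteSize() else 'FAILED', file_path)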

  • buildgrid/_app/commands/cmd_execute.py
    @@ -30,9 +30,10 @@ from urllib.parse import urlparse
     import click
     import grpc

    -from buildgrid.utils import merkle_maker, create_digest, write_fetch_blob
    +from buildgrid.client.cas import upload
     from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2, remote_execution_pb2_grpc
     from buildgrid._protos.google.bytestream import bytestream_pb2_grpc
    +from buildgrid.utils import create_digest, write_fetch_blob

     from ..cli import pass_context

    @@ -87,7 +88,7 @@ def request_dummy(context, number, wait_for_completion):
                                                       action_digest=action_digest,
                                                       skip_cache_lookup=True)

    -    responses = list()
    +    responses = []
         for _ in range(0, number):
             responses.append(stub.Execute(request))

    @@ -116,46 +117,37 @@ def request_dummy(context, number, wait_for_completion):
     @click.argument('input-root', nargs=1, type=click.Path(), required=True)
     @click.argument('commands', nargs=-1, type=click.STRING, required=True)
     @pass_context
    -def command(context, input_root, commands, output_file, output_directory):
    +def run_command(context, input_root, commands, output_file, output_directory):
         stub = remote_execution_pb2_grpc.ExecutionStub(context.channel)

    -    execute_command = remote_execution_pb2.Command()
    -
    -    for arg in commands:
    -        execute_command.arguments.extend([arg])
    -
         output_executeables = []
    -    for file, is_executeable in output_file:
    -        execute_command.output_files.extend([file])
    -        if is_executeable:
    -            output_executeables.append(file)
    +    with upload(context.channel, instance=context.instance_name) as uploader:
    +        command = remote_execution_pb2.Command()

    -    command_digest = create_digest(execute_command.SerializeToString())
    -    context.logger.info(command_digest)
    +        for arg in commands:
    +            command.arguments.extend([arg])

    -    # TODO: Check for missing blobs
    -    digest = None
    -    for _, digest in merkle_maker(input_root):
    -        pass
    +        for file, is_executeable in output_file:
    +            command.output_files.extend([file])
    +            if is_executeable:
    +                output_executeables.append(file)

    -    action = remote_execution_pb2.Action(command_digest=command_digest,
    -                                         input_root_digest=digest,
    -                                         do_not_cache=True)
    +        command_digest = uploader.put_message(command, queue=True)

    -    action_digest = create_digest(action.SerializeToString())
    +        context.logger.info('Sent command: {}'.format(command_digest))

    -    context.logger.info("Sending execution request...")
    +        # TODO: Check for missing blobs
    +        input_root_digest = uploader.upload_directory(input_root)
    +
    +        context.logger.info('Sent input: {}'.format(input_root_digest))

    -    requests = []
    -    requests.append(remote_execution_pb2.BatchUpdateBlobsRequest.Request(
    -        digest=command_digest, data=execute_command.SerializeToString()))
    +        action = remote_execution_pb2.Action(command_digest=command_digest,
    +                                             input_root_digest=input_root_digest,
    +                                             do_not_cache=True)

    -    requests.append(remote_execution_pb2.BatchUpdateBlobsRequest.Request(
    -        digest=action_digest, data=action.SerializeToString()))
    +        action_digest = uploader.put_message(action, queue=True)

    -    request = remote_execution_pb2.BatchUpdateBlobsRequest(instance_name=context.instance_name,
    -                                                           requests=requests)
    -    remote_execution_pb2_grpc.ContentAddressableStorageStub(context.channel).BatchUpdateBlobs(request)
    +        context.logger.info("Sent action: {}".format(action_digest))

         request = remote_execution_pb2.ExecuteRequest(instance_name=context.instance_name,
                                                       action_digest=action_digest,
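
    The rewritten run_command pushes the Command, the input tree and the
    Action through a single uploader before issuing the execution request. A
    minimal sketch of that sequence; channel, instance_name and input_root
    are assumed placeholders:

        from buildgrid.client.cas import upload
        from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2

        with upload(channel, instance=instance_name) as uploader:
            command = remote_execution_pb2.Command(arguments=['echo', 'hello'])
            command_digest = uploader.put_message(command, queue=True)
            # upload_directory() walks input_root and returns the root digest:
            input_root_digest = uploader.upload_directory(input_root)
            action = remote_execution_pb2.Action(command_digest=command_digest,
                                                 input_root_digest=input_root_digest,
                                                 do_not_cache=True)
            action_digest = uploader.put_message(action, queue=True)

        # Queued messages are flushed when the context manager closes, so the
        # action is in CAS before execution is requested:
        request = remote_execution_pb2.ExecuteRequest(instance_name=instance_name,
                                                      action_digest=action_digest)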
    

  • buildgrid/client/cas.py
    @@ -17,18 +17,40 @@ from contextlib import contextmanager
     import uuid
     import os

    -from buildgrid.settings import HASH
    +import grpc
    +
     from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2, remote_execution_pb2_grpc
     from buildgrid._protos.google.bytestream import bytestream_pb2, bytestream_pb2_grpc
    +from buildgrid._protos.google.rpc import code_pb2
    +from buildgrid.settings import HASH
    +from buildgrid.utils import merkle_tree_maker
    +
    +
    +class _CallCache:
    +    """Per remote grpc.StatusCode.UNIMPLEMENTED call cache."""
    +    __calls = {}
    +
    +    @classmethod
    +    def mark_unimplemented(cls, channel, name):
    +        if channel not in cls.__calls:
    +            cls.__calls[channel] = set()
    +        cls.__calls[channel].add(name)
    +
    +    @classmethod
    +    def unimplemented(cls, channel, name):
    +        if channel not in cls.__calls:
    +            return False
    +        return name in cls.__calls[channel]


     @contextmanager
     def upload(channel, instance=None, u_uid=None):
    +    """Context manager generator for the :class:`Uploader` class."""
         uploader = Uploader(channel, instance=instance, u_uid=u_uid)
         try:
             yield uploader
         finally:
    -        uploader.flush()
    +        uploader.close()


     class Uploader:

    @@ -37,8 +59,10 @@ class Uploader:
         The :class:`Uploader` class comes with a generator factory function that can
         be used together with the `with` statement for context management::

    -        with upload(channel, instance='build') as cas:
    -            cas.upload_file('/path/to/local/file')
    +        from buildgrid.client.cas import upload
    +
    +        with upload(channel, instance='build') as uploader:
    +            uploader.upload_file('/path/to/local/file')

         Attributes:
             FILE_SIZE_THRESHOLD (int): maximum size for a queueable file.

    @@ -47,6 +71,7 @@ class Uploader:

         FILE_SIZE_THRESHOLD = 1 * 1024 * 1024
         MAX_REQUEST_SIZE = 2 * 1024 * 1024
    +    MAX_REQUEST_COUNT = 500

         def __init__(self, channel, instance=None, u_uid=None):
             """Initializes a new :class:`Uploader` instance.

    @@ -67,19 +92,72 @@ class Uploader:
             self.__bytestream_stub = bytestream_pb2_grpc.ByteStreamStub(self.channel)
             self.__cas_stub = remote_execution_pb2_grpc.ContentAddressableStorageStub(self.channel)

    -        self.__requests = dict()
    +        self.__requests = {}
    +        self.__request_count = 0
             self.__request_size = 0

    +    # --- Public API ---
    +
    +    def put_blob(self, blob, digest=None, queue=False):
    +        """Stores a blob into the remote CAS server.
    +
    +        If queuing is allowed (`queue=True`), the upload request **may** be
    +        deferred. An explicit call to :func:`~flush` can force the request to
    +        be sent immediately (along with the rest of the queued batch).
    +
    +        Args:
    +            blob (bytes): the blob's data.
    +            digest (:obj:`Digest`, optional): the blob's digest.
    +            queue (bool, optional): whether or not the upload request may be
    +                queued and submitted as part of a batch upload request. Defaults
    +                to False.
    +
    +        Returns:
    +            :obj:`Digest`: the sent blob's digest.
    +        """
    +        if not queue or len(blob) > Uploader.FILE_SIZE_THRESHOLD:
    +            blob_digest = self._send_blob(blob, digest=digest)
    +        else:
    +            blob_digest = self._queue_blob(blob, digest=digest)
    +
    +        return blob_digest
    +
    +    def put_message(self, message, digest=None, queue=False):
    +        """Stores a message into the remote CAS server.
    +
    +        If queuing is allowed (`queue=True`), the upload request **may** be
    +        deferred. An explicit call to :func:`~flush` can force the request to
    +        be sent immediately (along with the rest of the queued batch).
    +
    +        Args:
    +            message (:obj:`Message`): the message object.
    +            digest (:obj:`Digest`, optional): the message's digest.
    +            queue (bool, optional): whether or not the upload request may be
    +                queued and submitted as part of a batch upload request. Defaults
    +                to False.
    +
    +        Returns:
    +            :obj:`Digest`: the sent message's digest.
    +        """
    +        message_blob = message.SerializeToString()
    +
    +        if not queue or len(message_blob) > Uploader.FILE_SIZE_THRESHOLD:
    +            message_digest = self._send_blob(message_blob, digest=digest)
    +        else:
    +            message_digest = self._queue_blob(message_blob, digest=digest)
    +
    +        return message_digest
    +
         def upload_file(self, file_path, queue=True):
             """Stores a local file into the remote CAS storage.

             If queuing is allowed (`queue=True`), the upload request **may** be
    -        defer. An explicit call to :method:`flush` can force the request to be
    +        deferred. An explicit call to :func:`~flush` can force the request to be
             sent immediately (along with the rest of the queued batch).

             Args:
                 file_path (str): absolute or relative path to a local file.
    -            queue (bool, optional): wheter or not the upload request may be
    +            queue (bool, optional): whether or not the upload request may be
                     queued and submitted as part of a batch upload request. Defaults
                     to True.

    @@ -87,7 +165,8 @@ class Uploader:
                 :obj:`Digest`: The digest of the file's content.

             Raises:
    -            OSError: If `file_path` does not exist or is not readable.
    +            FileNotFoundError: If `file_path` does not exist.
    +            PermissionError: If `file_path` is not readable.
             """
             if not os.path.isabs(file_path):
                 file_path = os.path.abspath(file_path)

    @@ -96,80 +175,135 @@ class Uploader:
                 file_bytes = bytes_steam.read()

             if not queue or len(file_bytes) > Uploader.FILE_SIZE_THRESHOLD:
    -            blob_digest = self._send_blob(file_bytes)
    +            file_digest = self._send_blob(file_bytes)
             else:
    -            blob_digest = self._queue_blob(file_bytes)
    +            file_digest = self._queue_blob(file_bytes)

    -        return blob_digest
    +        return file_digest

    -    def upload_directory(self, directory, queue=True):
    -        """Stores a :obj:`Directory` into the remote CAS storage.
    +    def upload_directory(self, directory_path, queue=True):
    +        """Stores a local folder into the remote CAS storage.

             If queuing is allowed (`queue=True`), the upload request **may** be
    -        defer. An explicit call to :method:`flush` can force the request to be
    +        deferred. An explicit call to :func:`~flush` can force the request to be
             sent immediately (along with the rest of the queued batch).

             Args:
    -            directory (:obj:`Directory`): a :obj:`Directory` object.
    -            queue (bool, optional): wheter or not the upload request may be
    +            directory_path (str): absolute or relative path to a local folder.
    +            queue (bool, optional): whether or not the upload requests may be
                     queued and submitted as part of a batch upload request. Defaults
                     to True.

             Returns:
    -            :obj:`Digest`: The digest of the :obj:`Directory`.
    +            :obj:`Digest`: The digest of the top :obj:`Directory`.
    +
    +        Raises:
    +            FileNotFoundError: If `directory_path` does not exist.
    +            PermissionError: If `directory_path` is not readable.
             """
    -        if not isinstance(directory, remote_execution_pb2.Directory):
    -            raise TypeError
    +        if not os.path.isabs(directory_path):
    +            directory_path = os.path.abspath(directory_path)
    +
    +        last_directory_node = None

             if not queue:
    -            return self._send_blob(directory.SerializeToString())
    +            for node, blob, _ in merkle_tree_maker(directory_path):
    +                if node.DESCRIPTOR is remote_execution_pb2.DirectoryNode.DESCRIPTOR:
    +                    last_directory_node = node
    +
    +                self._send_blob(blob, digest=node.digest)
    +
             else:
    -            return self._queue_blob(directory.SerializeToString())
    +            for node, blob, _ in merkle_tree_maker(directory_path):
    +                if node.DESCRIPTOR is remote_execution_pb2.DirectoryNode.DESCRIPTOR:
    +                    last_directory_node = node
    +
    +                self._queue_blob(blob, digest=node.digest)
    +
    +        return last_directory_node.digest

    -    def send_message(self, message):
    -        """Stores a message into the remote CAS storage.
    +    def upload_tree(self, directory_path, queue=True):
    +        """Stores a local folder into the remote CAS storage as a :obj:`Tree`.
    +
    +        If queuing is allowed (`queue=True`), the upload request **may** be
    +        deferred. An explicit call to :func:`~flush` can force the request to
    +        be sent immediately (along with the rest of the queued batch).

             Args:
    -            message (:obj:`Message`): a protobuf message object.
    +            directory_path (str): absolute or relative path to a local folder.
    +            queue (bool, optional): whether or not the upload requests may be
    +                queued and submitted as part of a batch upload request. Defaults
    +                to True.

             Returns:
    -            :obj:`Digest`: The digest of the message.
    +            :obj:`Digest`: The digest of the :obj:`Tree`.
    +
    +        Raises:
    +            FileNotFoundError: If `directory_path` does not exist.
    +            PermissionError: If `directory_path` is not readable.
             """
    -        return self._send_blob(message.SerializeToString())
    +        if not os.path.isabs(directory_path):
    +            directory_path = os.path.abspath(directory_path)
    +
    +        directories = []
    +
    +        if not queue:
    +            for node, blob, _ in merkle_tree_maker(directory_path):
    +                if node.DESCRIPTOR is remote_execution_pb2.DirectoryNode.DESCRIPTOR:
    +                    # TODO: Get the Directory object from merkle_tree_maker():
    +                    directory = remote_execution_pb2.Directory()
    +                    directory.ParseFromString(blob)
    +                    directories.append(directory)
    +
    +                self._send_blob(blob, digest=node.digest)
    +
    +        else:
    +            for node, blob, _ in merkle_tree_maker(directory_path):
    +                if node.DESCRIPTOR is remote_execution_pb2.DirectoryNode.DESCRIPTOR:
    +                    # TODO: Get the Directory object from merkle_tree_maker():
    +                    directory = remote_execution_pb2.Directory()
    +                    directory.ParseFromString(blob)
    +                    directories.append(directory)
    +
    +                self._queue_blob(blob, digest=node.digest)
    +
    +        tree = remote_execution_pb2.Tree()
    +        tree.root.CopyFrom(directories[-1])
    +        tree.children.extend(directories[:-1])
    +
    +        return self.put_message(tree, queue=queue)

         def flush(self):
             """Ensures any queued request gets sent."""
             if self.__requests:
    -            self._send_batch()
    +            self._send_blob_batch(self.__requests)

    -    def _queue_blob(self, blob):
    -        """Queues a memory block for later batch upload"""
    -        blob_digest = remote_execution_pb2.Digest()
    -        blob_digest.hash = HASH(blob).hexdigest()
    -        blob_digest.size_bytes = len(blob)
    +            self.__requests.clear()
    +            self.__request_count = 0
    +            self.__request_size = 0

    -        if self.__request_size + len(blob) > Uploader.MAX_REQUEST_SIZE:
    -            self._send_batch()
    +    def close(self):
    +        """Closes the underlying connection stubs.

    -        update_request = remote_execution_pb2.BatchUpdateBlobsRequest.Request()
    -        update_request.digest.CopyFrom(blob_digest)
    -        update_request.data = blob
    -
    -        update_request_size = update_request.ByteSize()
    -        if self.__request_size + update_request_size > Uploader.MAX_REQUEST_SIZE:
    -            self._send_batch()
    +        Note:
    +            This will always send pending requests before closing connections,
    +            if any.
    +        """
    +        self.flush()

    -        self.__requests[update_request.digest.hash] = update_request
    -        self.__request_size += update_request_size
    +        self.__bytestream_stub = None
    +        self.__cas_stub = None

    -        return blob_digest
    +    # --- Private API ---

    -    def _send_blob(self, blob):
    +    def _send_blob(self, blob, digest=None):
             """Sends a memory block using ByteStream.Write()"""
             blob_digest = remote_execution_pb2.Digest()
    -        blob_digest.hash = HASH(blob).hexdigest()
    -        blob_digest.size_bytes = len(blob)
    -
    +        if digest is not None:
    +            blob_digest.CopyFrom(digest)
    +        else:
    +            blob_digest.hash = HASH(blob).hexdigest()
    +            blob_digest.size_bytes = len(blob)
             if self.instance_name is not None:
                 resource_name = '/'.join([self.instance_name, 'uploads', self.u_uid, 'blobs',
                                           blob_digest.hash, str(blob_digest.size_bytes)])

    @@ -182,7 +316,7 @@ class Uploader:
                 finished = False
                 remaining = len(content)
                 while not finished:
    -                chunk_size = min(remaining, 64 * 1024)
    +                chunk_size = min(remaining, Uploader.MAX_REQUEST_SIZE)
                     remaining -= chunk_size

                     request = bytestream_pb2.WriteRequest()

    @@ -204,18 +338,68 @@ class Uploader:

             return blob_digest

    -    def _send_batch(self):
    +    def _queue_blob(self, blob, digest=None):
    +        """Queues a memory block for later batch upload"""
    +        blob_digest = remote_execution_pb2.Digest()
    +        if digest is not None:
    +            blob_digest.CopyFrom(digest)
    +        else:
    +            blob_digest.hash = HASH(blob).hexdigest()
    +            blob_digest.size_bytes = len(blob)
    +
    +        if self.__request_size + blob_digest.size_bytes > Uploader.MAX_REQUEST_SIZE:
    +            self.flush()
    +        elif self.__request_count >= Uploader.MAX_REQUEST_COUNT:
    +            self.flush()
    +
    +        self.__requests[blob_digest.hash] = (blob, blob_digest)
    +        self.__request_count += 1
    +        self.__request_size += blob_digest.size_bytes
    +
    +        return blob_digest
    +
    +    def _send_blob_batch(self, batch):
             """Sends queued data using ContentAddressableStorage.BatchUpdateBlobs()"""
    -        batch_request = remote_execution_pb2.BatchUpdateBlobsRequest()
    -        batch_request.requests.extend(self.__requests.values())
    -        if self.instance_name is not None:
    -            batch_request.instance_name = self.instance_name
    +        batch_fetched = False
    +        written_digests = []

    -        batch_response = self.__cas_stub.BatchUpdateBlobs(batch_request)
    +        # First, try BatchUpdateBlobs(), unless already known to be unimplemented:
    +        if not _CallCache.unimplemented(self.channel, 'BatchUpdateBlobs'):
    +            batch_request = remote_execution_pb2.BatchUpdateBlobsRequest()
    +            if self.instance_name is not None:
    +                batch_request.instance_name = self.instance_name

    -        for response in batch_response.responses:
    -            assert response.digest.hash in self.__requests
    -            assert response.status.code is 0
    +            for blob, digest in batch.values():
    +                request = batch_request.requests.add()
    +                request.digest.CopyFrom(digest)
    +                request.data = blob

    -        self.__requests.clear()
    -        self.__request_size = 0
    +            try:
    +                batch_response = self.__cas_stub.BatchUpdateBlobs(batch_request)
    +                for response in batch_response.responses:
    +                    assert response.digest.hash in batch
    +
    +                    written_digests.append(response.digest)
    +                    if response.status.code != code_pb2.OK:
    +                        response.digest.Clear()
    +
    +                batch_fetched = True
    +
    +            except grpc.RpcError as e:
    +                status_code = e.code()
    +                if status_code == grpc.StatusCode.UNIMPLEMENTED:
    +                    _CallCache.mark_unimplemented(self.channel, 'BatchUpdateBlobs')
    +
    +                elif status_code == grpc.StatusCode.INVALID_ARGUMENT:
    +                    written_digests.clear()
    +                    batch_fetched = False
    +
    +                else:
    +                    assert False
    +
    +        # Fallback to Write() if no BatchUpdateBlobs():
    +        if not batch_fetched:
    +            for blob, digest in batch.values():
    +                written_digests.append(self._send_blob(blob, digest=digest))
    +
    +        return written_digests
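
    This is the heart of the series: send_message() becomes put_message(), a
    symmetric put_blob() appears, and batching gains size and count caps plus
    a ByteStream.Write() fallback for servers that do not implement
    BatchUpdateBlobs(). A short usage sketch of the new public surface; the
    endpoint address is hypothetical:

        import grpc

        from buildgrid.client.cas import upload
        from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2

        channel = grpc.insecure_channel('localhost:50051')  # hypothetical endpoint

        with upload(channel, instance='build') as uploader:
            # Small payloads can be queued; they go out in one
            # BatchUpdateBlobs() round-trip (or fall back to Write()):
            blob_digest = uploader.put_blob(b'raw-data', queue=True)
            action = remote_execution_pb2.Action(do_not_cache=True)
            action_digest = uploader.put_message(action, queue=True)
            # Anything above FILE_SIZE_THRESHOLD (1 MiB) bypasses the queue:
            file_digest = uploader.upload_file('/path/to/local/file')
        # close() flushes any pending batch before dropping the stubs.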

  • buildgrid/server/cas/storage/remote.py
    @@ -25,9 +25,13 @@ import logging

     import grpc

    -from buildgrid.utils import gen_fetch_blob, gen_write_request_blob
    +from buildgrid.client.cas import upload
     from buildgrid._protos.google.bytestream import bytestream_pb2_grpc
     from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2, remote_execution_pb2_grpc
    +from buildgrid._protos.google.rpc import code_pb2
    +from buildgrid._protos.google.rpc import status_pb2
    +from buildgrid.utils import gen_fetch_blob
    +from buildgrid.settings import HASH

     from .storage_abc import StorageABC

    @@ -36,7 +40,10 @@ class RemoteStorage(StorageABC):

         def __init__(self, channel, instance_name):
             self.logger = logging.getLogger(__name__)
    -        self._instance_name = instance_name
    +
    +        self.instance_name = instance_name
    +        self.channel = channel
    +
             self._stub_bs = bytestream_pb2_grpc.ByteStreamStub(channel)
             self._stub_cas = remote_execution_pb2_grpc.ContentAddressableStorageStub(channel)

    @@ -50,16 +57,12 @@ class RemoteStorage(StorageABC):
                 fetched_data = io.BytesIO()
                 length = 0

    -            for data in gen_fetch_blob(self._stub_bs, digest, self._instance_name):
    +            for data in gen_fetch_blob(self._stub_bs, digest, self.instance_name):
                     length += fetched_data.write(data)

    -            if length:
    -                assert digest.size_bytes == length
    -                fetched_data.seek(0)
    -                return fetched_data
    -
    -            else:
    -                return None
    +            assert digest.size_bytes == length
    +            fetched_data.seek(0)
    +            return fetched_data

             except grpc.RpcError as e:
                 if e.code() == grpc.StatusCode.NOT_FOUND:

    @@ -71,16 +74,14 @@ class RemoteStorage(StorageABC):
             return None

         def begin_write(self, digest):
    -        return io.BytesIO(digest.SerializeToString())
    +        return io.BytesIO()

         def commit_write(self, digest, write_session):
    -        write_session.seek(0)
    -
    -        for request in gen_write_request_blob(write_session, digest, self._instance_name):
    -            self._stub_bs.Write(request)
    +        with upload(self.channel, instance=self.instance_name) as uploader:
    +            uploader.put_blob(write_session.getvalue())

         def missing_blobs(self, blobs):
    -        request = remote_execution_pb2.FindMissingBlobsRequest(instance_name=self._instance_name)
    +        request = remote_execution_pb2.FindMissingBlobsRequest(instance_name=self.instance_name)

             for blob in blobs:
                 request_digest = request.blob_digests.add()

    @@ -92,19 +93,15 @@ class RemoteStorage(StorageABC):
             return [x for x in response.missing_blob_digests]

         def bulk_update_blobs(self, blobs):
    -        request = remote_execution_pb2.BatchUpdateBlobsRequest(instance_name=self._instance_name)
    -
    -        for digest, data in blobs:
    -            reqs = request.requests.add()
    -            reqs.digest.CopyFrom(digest)
    -            reqs.data = data
    -
    -        response = self._stub_cas.BatchUpdateBlobs(request)
    -
    -        responses = response.responses
    -
    -        # Check everything was sent back, even if order changed
    -        assert ([x.digest for x in request.requests].sort(key=lambda x: x.hash)) == \
    -            ([x.digest for x in responses].sort(key=lambda x: x.hash))
    -
    -        return [x.status for x in responses]
    +        sent_digests = []
    +        with upload(self.channel, instance=self.instance_name) as uploader:
    +            for digest, blob in blobs:
    +                if len(blob) != digest.size_bytes or HASH(blob).hexdigest() != digest.hash:
    +                    sent_digests.append(remote_execution_pb2.Digest())
    +                else:
    +                    sent_digests.append(uploader.put_blob(blob, digest=digest, queue=True))
    +
    +        assert len(sent_digests) == len(blobs)
    +
    +        return [status_pb2.Status(code=code_pb2.OK) if d.ByteSize() > 0
    +                else status_pb2.Status(code=code_pb2.UNKNOWN) for d in sent_digests]
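
    bulk_update_blobs() now validates each digest locally before queueing the
    blob, and synthesizes the per-blob Status list itself. A sketch of how a
    caller sees that behavior; storage stands for an assumed RemoteStorage
    instance:

        from buildgrid.utils import create_digest

        blob = b'some blob'
        good_pair = (create_digest(blob), blob)
        bad_pair = (create_digest(b'other data'), blob)  # digest/content mismatch

        statuses = storage.bulk_update_blobs([good_pair, bad_pair])
        # The matching pair is queued and reported as code_pb2.OK; the
        # mismatching pair is rejected locally and reported as code_pb2.UNKNOWN.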

  • buildgrid/utils.py
    ... ... @@ -15,7 +15,6 @@
    15 15
     
    
    16 16
     from operator import attrgetter
    
    17 17
     import os
    
    18
    -import uuid
    
    19 18
     
    
    20 19
     from buildgrid.settings import HASH
    
    21 20
     from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2
    
    ... ... @@ -34,32 +33,6 @@ def gen_fetch_blob(stub, digest, instance_name=""):
    34 33
             yield response.data
    
    35 34
     
    
    36 35
     
    
    37
    -def gen_write_request_blob(digest_bytes, digest, instance_name=""):
    
    38
    -    """ Generates a bytestream write request
    
    39
    -    """
    
    40
    -    resource_name = os.path.join(instance_name, 'uploads', str(uuid.uuid4()),
    
    41
    -                                 'blobs', digest.hash, str(digest.size_bytes))
    
    42
    -
    
    43
    -    offset = 0
    
    44
    -    finished = False
    
    45
    -    remaining = digest.size_bytes
    
    46
    -
    
    47
    -    while not finished:
    
    48
    -        chunk_size = min(remaining, 64 * 1024)
    
    49
    -        remaining -= chunk_size
    
    50
    -        finished = remaining <= 0
    
    51
    -
    
    52
    -        request = bytestream_pb2.WriteRequest()
    
    53
    -        request.resource_name = resource_name
    
    54
    -        request.write_offset = offset
    
    55
    -        request.data = digest_bytes.read(chunk_size)
    
    56
    -        request.finish_write = finished
    
    57
    -
    
    58
    -        yield request
    
    59
    -
    
    60
    -        offset += chunk_size
    
    61
    -
    
    62
    -
    
    63 36
     def write_fetch_directory(root_directory, stub, digest, instance_name=None):
    
    64 37
         """Locally replicates a directory from CAS.
    
    65 38
     
    
    ... ... @@ -137,250 +110,170 @@ def create_digest(bytes_to_digest):
    137 110
             bytes_to_digest (bytes): byte data to digest.
    
    138 111
     
    
    139 112
         Returns:
    
    140
    -        :obj:`Digest`: The gRPC :obj:`Digest` for the given byte data.
    
    113
    +        :obj:`Digest`: The :obj:`Digest` for the given byte data.
    
    141 114
         """
    
    142 115
         return remote_execution_pb2.Digest(hash=HASH(bytes_to_digest).hexdigest(),
    
    143 116
                                            size_bytes=len(bytes_to_digest))
    
    144 117
     
    
    145 118
     
    
    146
    -def merkle_maker(directory):
    
    147
    -    """ Walks thorugh given directory, yielding the binary and digest
    
    148
    -    """
    
    149
    -    directory_pb2 = remote_execution_pb2.Directory()
    
    150
    -    for (dir_path, dir_names, file_names) in os.walk(directory):
    
    151
    -
    
    152
    -        for file_name in file_names:
    
    153
    -            file_path = os.path.join(dir_path, file_name)
    
    154
    -            chunk = read_file(file_path)
    
    155
    -            file_digest = create_digest(chunk)
    
    156
    -            directory_pb2.files.extend([file_maker(file_path, file_digest)])
    
    157
    -            yield chunk, file_digest
    
    158
    -
    
    159
    -        for inner_dir in dir_names:
    
    160
    -            inner_dir_path = os.path.join(dir_path, inner_dir)
    
    161
    -            yield from merkle_maker(inner_dir_path)
    
    162
    -
    
    163
    -    directory_string = directory_pb2.SerializeToString()
    
    119
    +def read_file(file_path):
    
    120
    +    """Loads raw file content in memory.
    
    164 121
     
    
    165
    -    yield directory_string, create_digest(directory_string)
    
    122
    +    Args:
    
    123
    +        file_path (str): path to the target file.
    
    166 124
     
    
    125
    +    Returns:
    
    126
    +        bytes: Raw file's content until EOF.
    
    167 127
     
    
    168
    -def file_maker(file_path, file_digest):
    
    169
    -    """ Creates a File Node
    
    128
    +    Raises:
    
    129
    +        OSError: If `file_path` does not exist or is not readable.
    
    170 130
         """
    
    171
    -    _, file_name = os.path.split(file_path)
    
    172
    -    return remote_execution_pb2.FileNode(name=file_name,
    
    173
    -                                         digest=file_digest,
    
    174
    -                                         is_executable=os.access(file_path, os.X_OK))
    
    131
    +    with open(file_path, 'rb') as byte_file:
    
    132
    +        return byte_file.read()
    
    175 133
     
    
    176 134
     
    
    177
    -def directory_maker(directory_path, child_directories=None, cas=None, upload_directories=True):
    
    178
    -    """Creates a :obj:`Directory` from a local directory and possibly upload it.
    
    135
    +def write_file(file_path, content):
    
    136
    +    """Dumps raw memory content to a file.
    
    179 137
     
    
    180 138
         Args:
    
    181
    -        directory_path (str): absolute or relative path to a local directory.
    
    182
    -        child_directories (list): output list of of children :obj:`Directory`
    
    183
    -            objects.
    
    184
    -        cas (:obj:`Uploader`): a CAS client uploader.
    
    185
    -        upload_directories (bool): wheter or not to upload the :obj:`Directory`
    
    186
    -            objects along with the files.
    
    139
    +        file_path (str): path to the target file.
    
    140
    +        content (bytes): raw file's content.
    
    187 141
     
    
    188
    -    Returns:
    
    189
    -        :obj:`Directory`, :obj:`Digest`: Tuple of a new gRPC :obj:`Directory`
    
    190
    -        for the local directory pointed by `directory_path` and the digest
    
    191
    -        for that object.
    
    142
    +    Raises:
    
    143
    +        OSError: If `file_path` does not exist or is not writable.
    
    192 144
         """
    
    193
    -    if not os.path.isabs(directory_path):
    
    194
    -        directory_path = os.path.abspath(directory_path)
    
    195
    -
    
    196
    -    files, directories, symlinks = list(), list(), list()
    
    197
    -    for directory_entry in os.scandir(directory_path):
    
    198
    -        # Create a FileNode and corresponding BatchUpdateBlobsRequest:
    
    199
    -        if directory_entry.is_file(follow_symlinks=False):
    
    200
    -            if cas is not None:
    
    201
    -                node_digest = cas.upload_file(directory_entry.path)
    
    202
    -            else:
    
    203
    -                node_digest = create_digest(read_file(directory_entry.path))
    
    204
    -
    
    205
    -            node = remote_execution_pb2.FileNode()
    
    206
    -            node.name = directory_entry.name
    
    207
    -            node.digest.CopyFrom(node_digest)
    
    208
    -            node.is_executable = os.access(directory_entry.path, os.X_OK)
    
    209
    -
    
    210
    -            files.append(node)
    
    211
    -
    
    212
    -        # Create a DirectoryNode and corresponding BatchUpdateBlobsRequest:
    
    213
    -        elif directory_entry.is_dir(follow_symlinks=False):
    
    214
    -            _, node_digest = directory_maker(directory_entry.path,
    
    215
    -                                             child_directories=child_directories,
    
    216
    -                                             upload_directories=upload_directories,
    
    217
    -                                             cas=cas)
    
    218
    -
    
    219
    -            node = remote_execution_pb2.DirectoryNode()
    
    220
    -            node.name = directory_entry.name
    
    221
    -            node.digest.CopyFrom(node_digest)
    
    222
    -
    
    223
    -            directories.append(node)
    
    224
    -
    
    225
    -        # Create a SymlinkNode if necessary;
    
    226
    -        elif os.path.islink(directory_entry.path):
    
    227
    -            node_target = os.readlink(directory_entry.path)
    
    145
    +    with open(file_path, 'wb') as byte_file:
    
    146
    +        byte_file.write(content)
    
    147
    +        byte_file.flush()
    
    228 148
     
    
    229
    -            node = remote_execution_pb2.SymlinkNode()
    
    230
    -            node.name = directory_entry.name
    
    231
    -            node.target = node_target
    
    232 149
     
    
    233
    -            symlinks.append(node)
    
    150
    +def merkle_tree_maker(directory_path):
    
    151
    +    """Walks a local folder tree, generating :obj:`FileNode` and
    
    152
    +    :obj:`DirectoryNode`.
    
    234 153
     
    
    235
    -    files.sort(key=attrgetter('name'))
    
    236
    -    directories.sort(key=attrgetter('name'))
    
    237
    -    symlinks.sort(key=attrgetter('name'))
    
    154
    +    Args:
    
    155
    +        directory_path (str): absolute or relative path to a local directory.
    
    238 156
     
    
    239
    -    directory = remote_execution_pb2.Directory()
    
    240
    -    directory.files.extend(files)
    
    241
    -    directory.directories.extend(directories)
    
    242
    -    directory.symlinks.extend(symlinks)
    
    157
    +    Yields:
    
    158
    +        :obj:`Message`, bytes, str: a tutple of either a :obj:`FileNode` or
    
    159
    +        :obj:`DirectoryNode` message, the corresponding blob and the
    
    160
    +        corresponding node path.
    
    161
    +    """
    
    162
    +    directory_name = os.path.basename(directory_path)
    
    243 163
     
    
    244
    -    if child_directories is not None:
    
    245
    -        child_directories.append(directory)
    
    164
    +    # Actual generator, recursively yields FileNodes and DirectoryNodes:
    
    165
    +    def __merkle_tree_maker(directory_path, directory_name):
    
    166
    +        if not os.path.isabs(directory_path):
    
    167
    +            directory_path = os.path.abspath(directory_path)
    
    246 168
     
    
    247
    -    if cas is not None and upload_directories:
    
    248
    -        directory_digest = cas.upload_directory(directory)
    
    249
    -    else:
    
    250
    -        directory_digest = create_digest(directory.SerializeToString())
    
    169
    +        directory = remote_execution_pb2.Directory()
    
    251 170
     
    
    252
    -    return directory, directory_digest
    
    171
    +        files, directories, symlinks = [], [], []
    
    172
    +        for directory_entry in os.scandir(directory_path):
    
    173
    +            node_name, node_path = directory_entry.name, directory_entry.path
    
    253 174
     
    
    175
    +            if directory_entry.is_file(follow_symlinks=False):
    
    176
    +                node_blob = read_file(directory_entry.path)
    
    177
    +                node_digest = create_digest(node_blob)
    
    254 178
     
    
    255
    -def tree_maker(directory_path, cas=None):
    
    256
    -    """Creates a :obj:`Tree` from a local directory and possibly upload it.
    
    179
    +                node = remote_execution_pb2.FileNode()
    
    180
    +                node.name = node_name
    
    181
    +                node.digest.CopyFrom(node_digest)
    
    182
    +                node.is_executable = os.access(node_path, os.X_OK)
    
    257 183
     
    
    258
    -    If `cas` is specified, the local directory content will be uploded/stored
    
    259
    -    in remote CAS (the :obj:`Tree` message won't).
    
    184
    +                files.append(node)
    
    260 185
     
    
    261
    -    Args:
    
    262
    -        directory_path (str): absolute or relative path to a local directory.
    
    263
    -        cas (:obj:`Uploader`): a CAS client uploader.
    
    186
    +                yield node, node_blob, node_path
    
    264 187
     
    
    265
    -    Returns:
    
    266
    -        :obj:`Tree`, :obj:`Digest`: Tuple of a new gRPC :obj:`Tree` for the
    
    267
    -        local directory pointed by `directory_path` and the digest for that
    
    268
    -        object.
    
    269
    -    """
    
    270
    -    if not os.path.isabs(directory_path):
    
    271
    -        directory_path = os.path.abspath(directory_path)
    
    188
    +            elif directory_entry.is_dir(follow_symlinks=False):
    
    189
    +                node, node_blob, _ = yield from __merkle_tree_maker(node_path, node_name)
    
    272 190
     
    
    273
    -    child_directories = list()
    
    274
    -    directory, _ = directory_maker(directory_path,
    
    275
    -                                   child_directories=child_directories,
    
    276
    -                                   upload_directories=False,
    
    277
    -                                   cas=cas)
    
    191
    +                directories.append(node)
    
    278 192
     
    
    279
    -    tree = remote_execution_pb2.Tree()
    
    280
    -    tree.children.extend(child_directories)
    
    281
    -    tree.root.CopyFrom(directory)
    
    193
    +                yield node, node_blob, node_path
    
    282 194
     
    
    283
    -    if cas is not None:
    
    284
    -        tree_digest = cas.send_message(tree)
    
    285
    -    else:
    
    286
    -        tree_digest = create_digest(tree.SerializeToString())
    
    195
    +            # Create a SymlinkNode:
    
    196
    +            elif os.path.islink(directory_entry.path):
    
    197
    +                node_target = os.readlink(directory_entry.path)
    
    287 198
     
    
    288
    -    return tree, tree_digest
    
    199
    +                node = remote_execution_pb2.SymlinkNode()
    
    200
    +                node.name = directory_entry.name
    
    201
    +                node.target = node_target
    
    289 202
     
    
    203
    +                symlinks.append(node)
    
    290 204
     
    
    291
    -def read_file(file_path):
    
    292
    -    """Loads raw file content in memory.
    
    205
    +        files.sort(key=attrgetter('name'))
    
    206
    +        directories.sort(key=attrgetter('name'))
    
    207
    +        symlinks.sort(key=attrgetter('name'))
    
    293 208
     
    
    294
    -    Args:
    
    295
    -        file_path (str): path to the target file.
    
    209
    +        directory.files.extend(files)
    
    210
    +        directory.directories.extend(directories)
    
    211
    +        directory.symlinks.extend(symlinks)
    
    296 212
     
    
    297
    -    Returns:
    
    298
    -        bytes: Raw file's content until EOF.
    
    213
    +        node_blob = directory.SerializeToString()
    
    214
    +        node_digest = create_digest(node_blob)
    
    299 215
     
    
    300
    -    Raises:
    
    301
    -        OSError: If `file_path` does not exist or is not readable.
    
    302
    -    """
    
    303
    -    with open(file_path, 'rb') as byte_file:
    
    304
    -        return byte_file.read()
    
    216
    +        node = remote_execution_pb2.DirectoryNode()
    
    217
    +        node.name = directory_name
    
    218
    +        node.digest.CopyFrom(node_digest)
    
    305 219
     
    
    220
    +        return node, node_blob, directory_path
    
    306 221
     
    
    307
    -def write_file(file_path, content):
    
    308
    -    """Dumps raw memory content to a file.
    
    222
    +    node, node_blob, node_path = yield from __merkle_tree_maker(directory_path,
    
    223
    +                                                                directory_name)
    
    309 224
     
    
    310
    -    Args:
    
    311
    -        file_path (str): path to the target file.
    
    312
    -        content (bytes): raw file's content.
    
    313
    -
    
    314
    -    Raises:
    
    315
    -        OSError: If `file_path` does not exist or is not writable.
    
    316
    -    """
    
    317
    -    with open(file_path, 'wb') as byte_file:
    
    318
    -        byte_file.write(content)
    
    319
    -        byte_file.flush()
    
    225
    +    yield node, node_blob, node_path
    
    320 226
     
    
    321 227
     
    
    322
    -def output_file_maker(file_path, input_path, cas=None):
    
    228
    +def output_file_maker(file_path, input_path, file_digest):
    
    323 229
         """Creates an :obj:`OutputFile` from a local file and possibly upload it.
    
    324 230
     
    
    325
    -    If `cas` is specified, the local file will be uploded/stored in remote CAS
    
    326
    -    (the :obj:`OutputFile` message won't).
    
    327
    -
    
    328 231
         Note:
    
    329 232
             `file_path` **must** point inside or be relative to `input_path`.
    
    330 233
     
    
    331 234
         Args:
    
    332 235
             file_path (str): absolute or relative path to a local file.
    
    333 236
             input_path (str): absolute or relative path to the input root directory.
    
    334
    -        cas (:obj:`Uploader`): a CAS client uploader.
    
    237
    +        file_digest (:obj:`Digest`): the underlying file's digest.
    
    335 238
     
    
    336 239
         Returns:
    
    337
    -        :obj:`OutputFile`: a new gRPC :obj:`OutputFile` object for the file
    
    338
    -        pointed by `file_path`.
    
    240
    +        :obj:`OutputFile`: a new :obj:`OutputFile` object for the file pointed to
    
    241
    +        by `file_path`.
    
    339 242
         """
    
    340 243
         if not os.path.isabs(file_path):
    
    341 244
             file_path = os.path.abspath(file_path)
    
    342 245
         if not os.path.isabs(input_path):
    
    343 246
             input_path = os.path.abspath(input_path)
    
    344 247
     
    
    345
    -    if cas is not None:
    
    346
    -        file_digest = cas.upload_file(file_path)
    
    347
    -    else:
    
    348
    -        file_digest = create_digest(read_file(file_path))
    
    349
    -
    
    350 248
         output_file = remote_execution_pb2.OutputFile()
    
    351 249
         output_file.digest.CopyFrom(file_digest)
    
    352
    -    # OutputFile.path should be relative to the working direcory:
    
    250
    +    # OutputFile.path should be relative to the working directory
    
    353 251
         output_file.path = os.path.relpath(file_path, start=input_path)
    
    354 252
         output_file.is_executable = os.access(file_path, os.X_OK)
    
    355 253
     
    
    356 254
         return output_file
    
    357 255
     
    
    358 256
     
    
    359
    -def output_directory_maker(directory_path, working_path, cas=None):
    
    257
    +def output_directory_maker(directory_path, working_path, tree_digest):
    
    360 258
         """Creates an :obj:`OutputDirectory` from a local directory.
    
    361 259
     
    
    362
    -    If `cas` is specified, the local directory content will be uploded/stored
    
    363
    -    in remote CAS (the :obj:`OutputDirectory` message won't).
    
    364
    -
    
    365 260
         Note:
    
    366 261
        `directory_path` **must** point inside or be relative to `working_path`.
    
    367 262
     
    
    368 263
         Args:
    
    369 264
             directory_path (str): absolute or relative path to a local directory.
    
    370 265
             working_path (str): absolute or relative path to the working directory.
    
    371
    -        cas (:obj:`Uploader`): a CAS client uploader.
    
    266
    +        tree_digest (:obj:`Digest`): the underlying folder tree's digest.
    
    372 267
     
    
    373 268
         Returns:
    
    374
    -        :obj:`OutputDirectory`: a new gRPC :obj:`OutputDirectory` for the
    
    375
    -        directory pointed by `directory_path`.
    
    269
    +        :obj:`OutputDirectory`: a new :obj:`OutputDirectory` for the directory
    
    270
    +        pointed to by `directory_path`.
    
    376 271
         """
    
    377 272
         if not os.path.isabs(directory_path):
    
    378 273
             directory_path = os.path.abspath(directory_path)
    
    379 274
         if not os.path.isabs(working_path):
    
    380 275
             working_path = os.path.abspath(working_path)
    
    381 276
     
    
    382
    -    _, tree_digest = tree_maker(directory_path, cas=cas)
    
    383
    -
    
    384 277
         output_directory = remote_execution_pb2.OutputDirectory()
    
    385 278
         output_directory.tree_digest.CopyFrom(tree_digest)
    
    386 279
         output_directory.path = os.path.relpath(directory_path, start=working_path)
    

  • setup.py
    ... ... @@ -89,6 +89,7 @@ tests_require = [
    89 89
         'coverage == 4.4.0',
    
    90 90
         'moto',
    
    91 91
         'pep8',
    
    92
    +    'psutil',
    
    92 93
         'pytest == 3.6.4',
    
    93 94
         'pytest-cov >= 2.6.0',
    
    94 95
         'pytest-pep8',
    

  • tests/cas/data/hello.cc
    1
    +#include <iostream>
    
    2
    +
    
    3
    +int main()
    
    4
    +{
    
    5
    +  std::cout << "Hello, World!" << std::endl;
    
    6
    +  return 0;
    
    7
    +}

  • tests/cas/data/hello/hello.c
    1
    +#include <stdio.h>
    
    2
    +
    
    3
    +#include "hello.h"
    
    4
    +
    
    5
    +int main()
    
    6
    +{
    
    7
    +  printf("%s\n", HELLO_WORLD);
    
    8
    +  return 0;
    
    9
    +}

  • tests/cas/data/hello/hello.h
    1
    +#define HELLO_WORLD "Hello, World!"

  • tests/cas/data/void

  • tests/cas/test_client.py
    1
    +# Copyright (C) 2018 Bloomberg LP
    
    2
    +#
    
    3
    +# Licensed under the Apache License, Version 2.0 (the "License");
    
    4
    +# you may not use this file except in compliance with the License.
    
    5
    +# You may obtain a copy of the License at
    
    6
    +#
    
    7
    +#  <http://www.apache.org/licenses/LICENSE-2.0>
    
    8
    +#
    
    9
    +# Unless required by applicable law or agreed to in writing, software
    
    10
    +# distributed under the License is distributed on an "AS IS" BASIS,
    
    11
    +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    
    12
    +# See the License for the specific language governing permissions and
    
    13
    +# limitations under the License.
    
    14
    +
    
    15
    +# pylint: disable=redefined-outer-name
    
    16
    +
    
    17
    +import os
    
    18
    +
    
    19
    +import grpc
    
    20
    +import pytest
    
    21
    +
    
    22
    +from buildgrid.client.cas import upload
    
    23
    +from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2
    
    24
    +from buildgrid.utils import create_digest
    
    25
    +
    
    26
    +from ..utils.cas import serve_cas, run_in_subprocess
    
    27
    +
    
    28
    +
    
    29
    +INSTANCES = ['', 'instance']
    
    30
    +BLOBS = [(b'',), (b'test-string',), (b'test', b'string')]
    
    31
    +MESSAGES = [
    
    32
    +    (remote_execution_pb2.Directory(),),
    
    33
    +    (remote_execution_pb2.SymlinkNode(name='name', target='target'),),
    
    34
    +    (remote_execution_pb2.Action(do_not_cache=True),
    
    35
    +     remote_execution_pb2.ActionResult(exit_code=12))
    
    36
    +]
    
    37
    +DATA_DIR = os.path.join(
    
    38
    +    os.path.dirname(os.path.realpath(__file__)), 'data')
    
    39
    +FILES = [
    
    40
    +    (os.path.join(DATA_DIR, 'void'),),
    
    41
    +    (os.path.join(DATA_DIR, 'hello.cc'),),
    
    42
    +    (os.path.join(DATA_DIR, 'hello', 'hello.c'),
    
    43
    +     os.path.join(DATA_DIR, 'hello', 'hello.h'))]
    
    44
    +DIRECTORIES = [
    
    45
    +    (os.path.join(DATA_DIR, 'hello'),),
    
    46
    +    (os.path.join(DATA_DIR, 'hello'), DATA_DIR)]
    
    47
    +
    
    48
    +
    
    49
    +@pytest.mark.parametrize('blobs', BLOBS)
    
    50
    +@pytest.mark.parametrize('instance', INSTANCES)
    
    51
    +def test_upload_blob(instance, blobs):
    
    52
    +    # Actual test function, to be run in a subprocess:
    
    53
    +    def __test_upload_blob(queue, remote, instance, blobs):
    
    54
    +        # Open a channel to the remote CAS server:
    
    55
    +        channel = grpc.insecure_channel(remote)
    
    56
    +
    
    57
    +        digests = []
    
    58
    +        with upload(channel, instance) as uploader:
    
    59
    +            if len(blobs) > 1:
    
    60
    +                for blob in blobs:
    
    61
    +                    digest = uploader.put_blob(blob, queue=True)
    
    62
    +                    digests.append(digest.SerializeToString())
    
    63
    +            else:
    
    64
    +                digest = uploader.put_blob(blobs[0], queue=False)
    
    65
    +                digests.append(digest.SerializeToString())
    
    66
    +
    
    67
    +        queue.put(digests)
    
    68
    +
    
    69
    +    # Start a minimal CAS server in a subprocess:
    
    70
    +    with serve_cas([instance]) as server:
    
    71
    +        digests = run_in_subprocess(__test_upload_blob,
    
    72
    +                                    server.remote, instance, blobs)
    
    73
    +
    
    74
    +        for blob, digest_blob in zip(blobs, digests):
    
    75
    +            digest = remote_execution_pb2.Digest()
    
    76
    +            digest.ParseFromString(digest_blob)
    
    77
    +
    
    78
    +            assert server.has(digest)
    
    79
    +            assert server.compare_blobs(digest, blob)
    
    80
    +
    
    81
    +
    
    82
    +@pytest.mark.parametrize('messages', MESSAGES)
    
    83
    +@pytest.mark.parametrize('instance', INSTANCES)
    
    84
    +def test_upload_message(instance, messages):
    
    85
    +    # Actual test function, to be run in a subprocess:
    
    86
    +    def __test_upload_message(queue, remote, instance, messages):
    
    87
    +        # Open a channel to the remote CAS server:
    
    88
    +        channel = grpc.insecure_channel(remote)
    
    89
    +
    
    90
    +        digests = []
    
    91
    +        with upload(channel, instance) as uploader:
    
    92
    +            if len(messages) > 1:
    
    93
    +                for message in messages:
    
    94
    +                    digest = uploader.put_message(message, queue=True)
    
    95
    +                    digests.append(digest.SerializeToString())
    
    96
    +            else:
    
    97
    +                digest = uploader.put_message(messages[0], queue=False)
    
    98
    +                digests.append(digest.SerializeToString())
    
    99
    +
    
    100
    +        queue.put(digests)
    
    101
    +
    
    102
    +    # Start a minimal CAS server in a subprocess:
    
    103
    +    with serve_cas([instance]) as server:
    
    104
    +        digests = run_in_subprocess(__test_upload_message,
    
    105
    +                                    server.remote, instance, messages)
    
    106
    +
    
    107
    +        for message, digest_blob in zip(messages, digests):
    
    108
    +            digest = remote_execution_pb2.Digest()
    
    109
    +            digest.ParseFromString(digest_blob)
    
    110
    +
    
    111
    +            assert server.has(digest)
    
    112
    +            assert server.compare_messages(digest, message)
    
    113
    +
    
    114
    +
    
    115
    +@pytest.mark.parametrize('file_paths', FILES)
    
    116
    +@pytest.mark.parametrize('instance', INSTANCES)
    
    117
    +def test_upload_file(instance, file_paths):
    
    118
    +    # Actual test function, to be run in a subprocess:
    
    119
    +    def __test_upload_file(queue, remote, instance, file_paths):
    
    120
    +        # Open a channel to the remote CAS server:
    
    121
    +        channel = grpc.insecure_channel(remote)
    
    122
    +
    
    123
    +        digests = []
    
    124
    +        with upload(channel, instance) as uploader:
    
    125
    +            if len(file_paths) > 1:
    
    126
    +                for file_path in file_paths:
    
    127
    +                    digest = uploader.upload_file(file_path, queue=True)
    
    128
    +                    digests.append(digest.SerializeToString())
    
    129
    +            else:
    
    130
    +                digest = uploader.upload_file(file_paths[0], queue=False)
    
    131
    +                digests.append(digest.SerializeToString())
    
    132
    +
    
    133
    +        queue.put(digests)
    
    134
    +
    
    135
    +    # Start a minimal CAS server in a subprocess:
    
    136
    +    with serve_cas([instance]) as server:
    
    137
    +        digests = run_in_subprocess(__test_upload_file,
    
    138
    +                                    server.remote, instance, file_paths)
    
    139
    +
    
    140
    +        for file_path, digest_blob in zip(file_paths, digests):
    
    141
    +            digest = remote_execution_pb2.Digest()
    
    142
    +            digest.ParseFromString(digest_blob)
    
    143
    +
    
    144
    +            assert server.has(digest)
    
    145
    +            assert server.compare_files(digest, file_path)
    
    146
    +
    
    147
    +
    
    148
    +@pytest.mark.parametrize('directory_paths', DIRECTORIES)
    
    149
    +@pytest.mark.parametrize('instance', INSTANCES)
    
    150
    +def test_upload_directory(instance, directory_paths):
    
    151
    +    # Actual test function, to be run in a subprocess:
    
    152
    +    def __test_upload_directory(queue, remote, instance, directory_paths):
    
    153
    +        # Open a channel to the remote CAS server:
    
    154
    +        channel = grpc.insecure_channel(remote)
    
    155
    +
    
    156
    +        digests = []
    
    157
    +        with upload(channel, instance) as uploader:
    
    158
    +            if len(directory_paths) > 1:
    
    159
    +                for directory_path in directory_paths:
    
    160
    +                    digest = uploader.upload_directory(directory_path, queue=True)
    
    161
    +                    digests.append(digest.SerializeToString())
    
    162
    +            else:
    
    163
    +                digest = uploader.upload_directory(directory_paths[0], queue=False)
    
    164
    +                digests.append(digest.SerializeToString())
    
    165
    +
    
    166
    +        queue.put(digests)
    
    167
    +
    
    168
    +    # Start a minimal CAS server in a subprocess:
    
    169
    +    with serve_cas([instance]) as server:
    
    170
    +        digests = run_in_subprocess(__test_upload_directory,
    
    171
    +                                    server.remote, instance, directory_paths)
    
    172
    +
    
    173
    +        for directory_path, digest_blob in zip(directory_paths, digests):
    
    174
    +            digest = remote_execution_pb2.Digest()
    
    175
    +            digest.ParseFromString(digest_blob)
    
    176
    +
    
    177
    +            assert server.compare_directories(digest, directory_path)
    
    178
    +
    
    179
    +
    
    180
    +@pytest.mark.parametrize('directory_paths', DIRECTORIES)
    
    181
    +@pytest.mark.parametrize('instance', INSTANCES)
    
    182
    +def test_upload_tree(instance, directory_paths):
    
    183
    +    # Actual test function, to be run in a subprocess:
    
    184
    +    def __test_upload_tree(queue, remote, instance, directory_paths):
    
    185
    +        # Open a channel to the remote CAS server:
    
    186
    +        channel = grpc.insecure_channel(remote)
    
    187
    +
    
    188
    +        digests = []
    
    189
    +        with upload(channel, instance) as uploader:
    
    190
    +            if len(directory_paths) > 1:
    
    191
    +                for directory_path in directory_paths:
    
    192
    +                    digest = uploader.upload_tree(directory_path, queue=True)
    
    193
    +                    digests.append(digest.SerializeToString())
    
    194
    +            else:
    
    195
    +                digest = uploader.upload_tree(directory_paths[0], queue=False)
    
    196
    +                digests.append(digest.SerializeToString())
    
    197
    +
    
    198
    +        queue.put(digests)
    
    199
    +
    
    200
    +    # Start a minimal CAS server in a subprocess:
    
    201
    +    with serve_cas([instance]) as server:
    
    202
    +        digests = run_in_subprocess(__test_upload_tree,
    
    203
    +                                    server.remote, instance, directory_paths)
    
    204
    +
    
    205
    +        for directory_path, digest_blob in zip(directory_paths, digests):
    
    206
    +            digest = remote_execution_pb2.Digest()
    
    207
    +            digest.ParseFromString(digest_blob)
    
    208
    +
    
    209
    +            assert server.has(digest)
    
    210
    +
    
    211
    +            tree = remote_execution_pb2.Tree()
    
    212
    +            tree.ParseFromString(server.get(digest))
    
    213
    +
    
    214
    +            directory_digest = create_digest(tree.root.SerializeToString())
    
    215
    +
    
    216
    +            assert server.compare_directories(directory_digest, directory_path)
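
All of these client tests follow the same two-step pattern: the gRPC calls run in a subprocess (keeping gRPC threads out of the pytest process, see tests/utils/cas.py below) and results travel back through a multiprocessing queue as serialized protos. A stripped-down sketch of just that round-trip, independent of any CAS server (the digest value is a placeholder, not a real hash):

    import multiprocessing

    from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2

    def child(queue):
        # Messages cross the process boundary as bytes, not as objects:
        digest = remote_execution_pb2.Digest(hash='not-a-real-hash', size_bytes=3)
        queue.put([digest.SerializeToString()])

    if __name__ == '__main__':
        queue = multiprocessing.Queue()
        process = multiprocessing.Process(target=child, args=(queue,))
        process.start()

        for blob in queue.get():
            digest = remote_execution_pb2.Digest()
            digest.ParseFromString(blob)  # reconstructed in the parent
        process.join()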

  • tests/cas/test_storage.py
    ... ... @@ -19,220 +19,285 @@
    19 19
     
    
    20 20
     import tempfile
    
    21 21
     
    
    22
    -from unittest import mock
    
    23
    -
    
    24 22
     import boto3
    
    25 23
     import grpc
    
    26
    -from grpc._server import _Context
    
    27 24
     import pytest
    
    28 25
     from moto import mock_s3
    
    29 26
     
    
    30
    -from buildgrid._protos.build.bazel.remote.execution.v2.remote_execution_pb2 import Digest
    
    31
    -from buildgrid.server.cas import service
    
    32
    -from buildgrid.server.cas.instance import ByteStreamInstance, ContentAddressableStorageInstance
    
    33
    -from buildgrid.server.cas.storage import remote
    
    27
    +from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2
    
    28
    +from buildgrid.server.cas.storage.remote import RemoteStorage
    
    34 29
     from buildgrid.server.cas.storage.lru_memory_cache import LRUMemoryCache
    
    35 30
     from buildgrid.server.cas.storage.disk import DiskStorage
    
    36 31
     from buildgrid.server.cas.storage.s3 import S3Storage
    
    37 32
     from buildgrid.server.cas.storage.with_cache import WithCacheStorage
    
    38 33
     from buildgrid.settings import HASH
    
    39 34
     
    
    35
    +from ..utils.cas import serve_cas, run_in_subprocess
    
    40 36
     
    
    41
    -context = mock.create_autospec(_Context)
    
    42
    -server = mock.create_autospec(grpc.server)
    
    43
    -
    
    44
    -abc = b"abc"
    
    45
    -abc_digest = Digest(hash=HASH(abc).hexdigest(), size_bytes=3)
    
    46
    -defg = b"defg"
    
    47
    -defg_digest = Digest(hash=HASH(defg).hexdigest(), size_bytes=4)
    
    48
    -hijk = b"hijk"
    
    49
    -hijk_digest = Digest(hash=HASH(hijk).hexdigest(), size_bytes=4)
    
    50
    -
    
    51
    -
    
    52
    -def write(storage, digest, blob):
    
    53
    -    session = storage.begin_write(digest)
    
    54
    -    session.write(blob)
    
    55
    -    storage.commit_write(digest, session)
    
    56
    -
    
    57
    -
    
    58
    -class MockCASStorage(ByteStreamInstance, ContentAddressableStorageInstance):
    
    59
    -
    
    60
    -    def __init__(self):
    
    61
    -        storage = LRUMemoryCache(256)
    
    62
    -        super().__init__(storage)
    
    63
    -
    
    64
    -
    
    65
    -# Mock a CAS server with LRUStorage to return "calls" made to it
    
    66
    -class MockStubServer:
    
    67
    -
    
    68
    -    def __init__(self):
    
    69
    -        instances = {"": MockCASStorage(), "dna": MockCASStorage()}
    
    70
    -        self._requests = []
    
    71
    -        with mock.patch.object(service, 'bytestream_pb2_grpc'):
    
    72
    -            self._bs_service = service.ByteStreamService(server)
    
    73
    -            for k, v in instances.items():
    
    74
    -                self._bs_service.add_instance(k, v)
    
    75
    -        with mock.patch.object(service, 'remote_execution_pb2_grpc'):
    
    76
    -            self._cas_service = service.ContentAddressableStorageService(server)
    
    77
    -            for k, v in instances.items():
    
    78
    -                self._cas_service.add_instance(k, v)
    
    79
    -
    
    80
    -    def Read(self, request):
    
    81
    -        yield from self._bs_service.Read(request, context)
    
    82
    -
    
    83
    -    def Write(self, request):
    
    84
    -        self._requests.append(request)
    
    85
    -        if request.finish_write:
    
    86
    -            response = self._bs_service.Write(self._requests, context)
    
    87
    -            self._requests = []
    
    88
    -            return response
    
    89
    -
    
    90
    -        return None
    
    91
    -
    
    92
    -    def FindMissingBlobs(self, request):
    
    93
    -        return self._cas_service.FindMissingBlobs(request, context)
    
    94
    -
    
    95
    -    def BatchUpdateBlobs(self, request):
    
    96
    -        return self._cas_service.BatchUpdateBlobs(request, context)
    
    97 37
     
    
    38
    +BLOBS = [(b'abc', b'defg', b'hijk', b'')]
    
    39
    +BLOBS_DIGESTS = [tuple([remote_execution_pb2.Digest(hash=HASH(blob).hexdigest(),
    
    40
    +                                                    size_bytes=len(blob)) for blob in blobs])
    
    41
    +                 for blobs in BLOBS]
    
    98 42
     
    
    99
    -# Instances of MockCASStorage
    
    100
    -@pytest.fixture(params=["", "dna"])
    
    101
    -def instance(params):
    
    102
    -    return {params, MockCASStorage()}
    
    103 43
     
    
    104
    -
    
    105
    -# General tests for all storage providers
    
    106
    -
    
    107
    -
    
    108
    -@pytest.fixture(params=["lru", "disk", "s3", "lru_disk", "disk_s3", "remote"])
    
    44
    +@pytest.fixture(params=['lru', 'disk', 's3', 'lru_disk', 'disk_s3', 'remote'])
    
    109 45
     def any_storage(request):
    
    110
    -    if request.param == "lru":
    
    46
    +    if request.param == 'lru':
    
    111 47
             yield LRUMemoryCache(256)
    
    112
    -    elif request.param == "disk":
    
    48
    +    elif request.param == 'disk':
    
    113 49
             with tempfile.TemporaryDirectory() as path:
    
    114 50
                 yield DiskStorage(path)
    
    115
    -    elif request.param == "s3":
    
    51
    +    elif request.param == 's3':
    
    116 52
             with mock_s3():
    
    117
    -            boto3.resource('s3').create_bucket(Bucket="testing")
    
    118
    -            yield S3Storage("testing")
    
    119
    -    elif request.param == "lru_disk":
    
    53
    +            boto3.resource('s3').create_bucket(Bucket='testing')
    
    54
    +            yield S3Storage('testing')
    
    55
    +    elif request.param == 'lru_disk':
    
    120 56
             # LRU cache with a uselessly small limit, so requests always fall back
    
    121 57
             with tempfile.TemporaryDirectory() as path:
    
    122 58
                 yield WithCacheStorage(LRUMemoryCache(1), DiskStorage(path))
    
    123
    -    elif request.param == "disk_s3":
    
    59
    +    elif request.param == 'disk_s3':
    
    124 60
             # Disk-based cache of S3, but we don't delete files, so requests
    
    125 61
             # are always handled by the cache
    
    126 62
             with tempfile.TemporaryDirectory() as path:
    
    127 63
                 with mock_s3():
    
    128
    -                boto3.resource('s3').create_bucket(Bucket="testing")
    
    129
    -                yield WithCacheStorage(DiskStorage(path), S3Storage("testing"))
    
    130
    -    elif request.param == "remote":
    
    131
    -        with mock.patch.object(remote, 'bytestream_pb2_grpc'):
    
    132
    -            with mock.patch.object(remote, 'remote_execution_pb2_grpc'):
    
    133
    -                mock_server = MockStubServer()
    
    134
    -                storage = remote.RemoteStorage(None, "")
    
    135
    -                storage._stub_bs = mock_server
    
    136
    -                storage._stub_cas = mock_server
    
    137
    -                yield storage
    
    138
    -
    
    139
    -
    
    140
    -def test_initially_empty(any_storage):
    
    141
    -    assert not any_storage.has_blob(abc_digest)
    
    142
    -    assert not any_storage.has_blob(defg_digest)
    
    143
    -    assert not any_storage.has_blob(hijk_digest)
    
    144
    -
    
    145
    -
    
    146
    -def test_basic_write_read(any_storage):
    
    147
    -    assert not any_storage.has_blob(abc_digest)
    
    148
    -    write(any_storage, abc_digest, abc)
    
    149
    -    assert any_storage.has_blob(abc_digest)
    
    150
    -    assert any_storage.get_blob(abc_digest).read() == abc
    
    151
    -
    
    152
    -    # Try writing the same digest again (since it's valid to do that)
    
    153
    -    write(any_storage, abc_digest, abc)
    
    154
    -    assert any_storage.has_blob(abc_digest)
    
    155
    -    assert any_storage.get_blob(abc_digest).read() == abc
    
    156
    -
    
    157
    -
    
    158
    -def test_bulk_write_read(any_storage):
    
    159
    -    missing_digests = any_storage.missing_blobs([abc_digest, defg_digest, hijk_digest])
    
    160
    -    assert len(missing_digests) == 3
    
    161
    -    assert abc_digest in missing_digests
    
    162
    -    assert defg_digest in missing_digests
    
    163
    -    assert hijk_digest in missing_digests
    
    64
    +                boto3.resource('s3').create_bucket(Bucket='testing')
    
    65
    +                yield WithCacheStorage(DiskStorage(path), S3Storage('testing'))
    
    66
    +    elif request.param == 'remote':
    
    67
    +        with serve_cas(['testing']) as server:
    
    68
    +            yield server.remote
    
    164 69
     
    
    165
    -    bulk_update_results = any_storage.bulk_update_blobs([(abc_digest, abc), (defg_digest, defg),
    
    166
    -                                                         (hijk_digest, b'????')])
    
    167
    -    assert len(bulk_update_results) == 3
    
    168
    -    assert bulk_update_results[0].code == 0
    
    169
    -    assert bulk_update_results[1].code == 0
    
    170
    -    assert bulk_update_results[2].code != 0
    
    171
    -
    
    172
    -    missing_digests = any_storage.missing_blobs([abc_digest, defg_digest, hijk_digest])
    
    173
    -    assert missing_digests == [hijk_digest]
    
    174
    -
    
    175
    -    assert any_storage.get_blob(abc_digest).read() == abc
    
    176
    -    assert any_storage.get_blob(defg_digest).read() == defg
    
    177
    -
    
    178
    -
    
    179
    -def test_nonexistent_read(any_storage):
    
    180
    -    assert any_storage.get_blob(abc_digest) is None
    
    181 70
     
    
    71
    +def write(storage, digest, blob):
    
    72
    +    session = storage.begin_write(digest)
    
    73
    +    session.write(blob)
    
    74
    +    storage.commit_write(digest, session)
    
    182 75
     
    
    183
    -# Tests for special behavior of individual storage providers
    
    184 76
     
    
    77
    +@pytest.mark.parametrize('blobs_digests', zip(BLOBS, BLOBS_DIGESTS))
    
    78
    +def test_initially_empty(any_storage, blobs_digests):
    
    79
    +    _, digests = blobs_digests
    
    80
    +
    
    81
    +    # Actual test function, failing on assertions:
    
    82
    +    def __test_initially_empty(any_storage, digests):
    
    83
    +        for digest in digests:
    
    84
    +            assert not any_storage.has_blob(digest)
    
    85
    +
    
    86
    +    # Helper test function for remote storage, to be run in a subprocess:
    
    87
    +    def __test_remote_initially_empty(queue, remote, serialized_digests):
    
    88
    +        channel = grpc.insecure_channel(remote)
    
    89
    +        remote_storage = RemoteStorage(channel, 'testing')
    
    90
    +        digests = []
    
    91
    +
    
    92
    +        for data in serialized_digests:
    
    93
    +            digest = remote_execution_pb2.Digest()
    
    94
    +            digest.ParseFromString(data)
    
    95
    +            digests.append(digest)
    
    96
    +
    
    97
    +        try:
    
    98
    +            __test_initially_empty(remote_storage, digests)
    
    99
    +        except AssertionError:
    
    100
    +            queue.put(False)
    
    101
    +        else:
    
    102
    +            queue.put(True)
    
    103
    +
    
    104
    +    if isinstance(any_storage, str):
    
    105
    +        serialized_digests = [digest.SerializeToString() for digest in digests]
    
    106
    +        assert run_in_subprocess(__test_remote_initially_empty,
    
    107
    +                                 any_storage, serialized_digests)
    
    108
    +    else:
    
    109
    +        __test_initially_empty(any_storage, digests)
    
    110
    +
    
    111
    +
    
    112
    +@pytest.mark.parametrize('blobs_digests', zip(BLOBS, BLOBS_DIGESTS))
    
    113
    +def test_basic_write_read(any_storage, blobs_digests):
    
    114
    +    blobs, digests = blobs_digests
    
    115
    +
    
    116
    +    # Actual test function, failing on assertions:
    
    117
    +    def __test_basic_write_read(any_storage, blobs, digests):
    
    118
    +        for blob, digest in zip(blobs, digests):
    
    119
    +            assert not any_storage.has_blob(digest)
    
    120
    +            write(any_storage, digest, blob)
    
    121
    +            assert any_storage.has_blob(digest)
    
    122
    +            assert any_storage.get_blob(digest).read() == blob
    
    123
    +
    
    124
    +            # Try writing the same digest again (since it's valid to do that)
    
    125
    +            write(any_storage, digest, blob)
    
    126
    +            assert any_storage.has_blob(digest)
    
    127
    +            assert any_storage.get_blob(digest).read() == blob
    
    128
    +
    
    129
    +    # Helper test function for remote storage, to be run in a subprocess:
    
    130
    +    def __test_remote_basic_write_read(queue, remote, blobs, serialized_digests):
    
    131
    +        channel = grpc.insecure_channel(remote)
    
    132
    +        remote_storage = RemoteStorage(channel, 'testing')
    
    133
    +        digests = []
    
    134
    +
    
    135
    +        for data in serialized_digests:
    
    136
    +            digest = remote_execution_pb2.Digest()
    
    137
    +            digest.ParseFromString(data)
    
    138
    +            digests.append(digest)
    
    139
    +
    
    140
    +        try:
    
    141
    +            __test_basic_write_read(remote_storage, blobs, digests)
    
    142
    +        except AssertionError:
    
    143
    +            queue.put(False)
    
    144
    +        else:
    
    145
    +            queue.put(True)
    
    146
    +
    
    147
    +    if isinstance(any_storage, str):
    
    148
    +        serialized_digests = [digest.SerializeToString() for digest in digests]
    
    149
    +        assert run_in_subprocess(__test_remote_basic_write_read,
    
    150
    +                                 any_storage, blobs, serialized_digests)
    
    151
    +    else:
    
    152
    +        __test_basic_write_read(any_storage, blobs, digests)
    
    153
    +
    
    154
    +
    
    155
    +@pytest.mark.parametrize('blobs_digests', zip(BLOBS, BLOBS_DIGESTS))
    
    156
    +def test_bulk_write_read(any_storage, blobs_digests):
    
    157
    +    blobs, digests = blobs_digests
    
    158
    +
    
    159
    +    # Actual test function, failing on assertions:
    
    160
    +    def __test_bulk_write_read(any_storage, blobs, digests):
    
    161
    +        missing_digests = any_storage.missing_blobs(digests)
    
    162
    +        assert len(missing_digests) == len(digests)
    
    163
    +        for digest in digests:
    
    164
    +            assert digest in missing_digests
    
    165
    +
    
    166
    +        faulty_blobs = list(blobs)
    
    167
    +        faulty_blobs[-1] = b'this-is-not-matching'
    
    168
    +
    
    169
    +        results = any_storage.bulk_update_blobs(list(zip(digests, faulty_blobs)))
    
    170
    +        assert len(results) == len(digests)
    
    171
    +        for result, blob, digest in zip(results[:-1], faulty_blobs[:-1], digests[:-1]):
    
    172
    +            assert result.code == 0
    
    173
    +            assert any_storage.get_blob(digest).read() == blob
    
    174
    +        assert results[-1].code != 0
    
    175
    +
    
    176
    +        missing_digests = any_storage.missing_blobs(digests)
    
    177
    +        assert len(missing_digests) == 1
    
    178
    +        assert missing_digests[0] == digests[-1]
    
    179
    +
    
    180
    +    # Helper test function for remote storage, to be run in a subprocess:
    
    181
    +    def __test_remote_bulk_write_read(queue, remote, blobs, serialized_digests):
    
    182
    +        channel = grpc.insecure_channel(remote)
    
    183
    +        remote_storage = RemoteStorage(channel, 'testing')
    
    184
    +        digests = []
    
    185
    +
    
    186
    +        for data in serialized_digests:
    
    187
    +            digest = remote_execution_pb2.Digest()
    
    188
    +            digest.ParseFromString(data)
    
    189
    +            digests.append(digest)
    
    190
    +
    
    191
    +        try:
    
    192
    +            __test_bulk_write_read(remote_storage, blobs, digests)
    
    193
    +        except AssertionError:
    
    194
    +            queue.put(False)
    
    195
    +        else:
    
    196
    +            queue.put(True)
    
    197
    +
    
    198
    +    if isinstance(any_storage, str):
    
    199
    +        serialized_digests = [digest.SerializeToString() for digest in digests]
    
    200
    +        assert run_in_subprocess(__test_remote_bulk_write_read,
    
    201
    +                                 any_storage, blobs, serialized_digests)
    
    202
    +    else:
    
    203
    +        __test_bulk_write_read(any_storage, blobs, digests)
    
    204
    +
    
    205
    +
    
    206
    +@pytest.mark.parametrize('blobs_digests', zip(BLOBS, BLOBS_DIGESTS))
    
    207
    +def test_nonexistent_read(any_storage, blobs_digests):
    
    208
    +    _, digests = blobs_digests
    
    209
    +
    
    210
    +    # Actual test function, failing on assertions:
    
    211
    +    def __test_nonexistent_read(any_storage, digests):
    
    212
    +        for digest in digests:
    
    213
    +            assert any_storage.get_blob(digest) is None
    
    214
    +
    
    215
    +    # Helper test function for remote storage, to be run in a subprocess:
    
    216
    +    def __test_remote_nonexistent_read(queue, remote, serialized_digests):
    
    217
    +        channel = grpc.insecure_channel(remote)
    
    218
    +        remote_storage = RemoteStorage(channel, 'testing')
    
    219
    +        digests = []
    
    220
    +
    
    221
    +        for data in serialized_digests:
    
    222
    +            digest = remote_execution_pb2.Digest()
    
    223
    +            digest.ParseFromString(data)
    
    224
    +            digests.append(digest)
    
    225
    +
    
    226
    +        try:
    
    227
    +            __test_nonexistent_read(remote_storage, digests)
    
    228
    +        except AssertionError:
    
    229
    +            queue.put(False)
    
    230
    +        else:
    
    231
    +            queue.put(True)
    
    232
    +
    
    233
    +    if isinstance(any_storage, str):
    
    234
    +        serialized_digests = [digest.SerializeToString() for digest in digests]
    
    235
    +        assert run_in_subprocess(__test_remote_nonexistent_read,
    
    236
    +                                 any_storage, serialized_digests)
    
    237
    +    else:
    
    238
    +        __test_nonexistent_read(any_storage, digests)
    
    239
    +
    
    240
    +
    
    241
    +@pytest.mark.parametrize('blobs_digests', [(BLOBS[0], BLOBS_DIGESTS[0])])
    
    242
    +def test_lru_eviction(blobs_digests):
    
    243
    +    blobs, digests = blobs_digests
    
    244
    +    blob1, blob2, blob3, *_ = blobs
    
    245
    +    digest1, digest2, digest3, *_ = digests
    
    185 246
     
    
    186
    -def test_lru_eviction():
    
    187 247
         lru = LRUMemoryCache(8)
    
    188
    -    write(lru, abc_digest, abc)
    
    189
    -    write(lru, defg_digest, defg)
    
    190
    -    assert lru.has_blob(abc_digest)
    
    191
    -    assert lru.has_blob(defg_digest)
    
    192
    -
    
    193
    -    write(lru, hijk_digest, hijk)
    
    194
    -    # Check that the LRU evicted abc (it was written first)
    
    195
    -    assert not lru.has_blob(abc_digest)
    
    196
    -    assert lru.has_blob(defg_digest)
    
    197
    -    assert lru.has_blob(hijk_digest)
    
    198
    -
    
    199
    -    assert lru.get_blob(defg_digest).read() == defg
    
    200
    -    write(lru, abc_digest, abc)
    
    201
    -    # Check that the LRU evicted hijk (since we just read defg)
    
    202
    -    assert lru.has_blob(abc_digest)
    
    203
    -    assert lru.has_blob(defg_digest)
    
    204
    -    assert not lru.has_blob(hijk_digest)
    
    205
    -
    
    206
    -    assert lru.has_blob(defg_digest)
    
    207
    -    write(lru, hijk_digest, abc)
    
    208
    -    # Check that the LRU evicted abc (since we just checked hijk)
    
    209
    -    assert not lru.has_blob(abc_digest)
    
    210
    -    assert lru.has_blob(defg_digest)
    
    211
    -    assert lru.has_blob(hijk_digest)
    
    212
    -
    
    213
    -
    
    214
    -def test_with_cache():
    
    248
    +    write(lru, digest1, blob1)
    
    249
    +    write(lru, digest2, blob2)
    
    250
    +    assert lru.has_blob(digest1)
    
    251
    +    assert lru.has_blob(digest2)
    
    252
    +
    
    253
    +    write(lru, digest3, blob3)
    
    254
    +    # Check that the LRU evicted blob1 (it was written first)
    
    255
    +    assert not lru.has_blob(digest1)
    
    256
    +    assert lru.has_blob(digest2)
    
    257
    +    assert lru.has_blob(digest3)
    
    258
    +
    
    259
    +    assert lru.get_blob(digest2).read() == blob2
    
    260
    +    write(lru, digest1, blob1)
    
    261
    +    # Check that the LRU evicted blob3 (since we just read blob2)
    
    262
    +    assert lru.has_blob(digest1)
    
    263
    +    assert lru.has_blob(digest2)
    
    264
    +    assert not lru.has_blob(digest3)
    
    265
    +
    
    266
    +    assert lru.has_blob(digest2)
    
    267
    +    write(lru, digest3, blob1)
    
    268
    +    # Check that the LRU evicted blob1 (since we just checked blob3)
    
    269
    +    assert not lru.has_blob(digest1)
    
    270
    +    assert lru.has_blob(digest2)
    
    271
    +    assert lru.has_blob(digest3)
    
    272
    +
    
    273
    +
    
    274
    +@pytest.mark.parametrize('blobs_digests', [(BLOBS[0], BLOBS_DIGESTS[0])])
    
    275
    +def test_with_cache(blobs_digests):
    
    276
    +    blobs, digests = blobs_digests
    
    277
    +    blob1, blob2, blob3, *_ = blobs
    
    278
    +    digest1, digest2, digest3, *_ = digests
    
    279
    +
    
    215 280
         cache = LRUMemoryCache(256)
    
    216 281
         fallback = LRUMemoryCache(256)
    
    217 282
         with_cache_storage = WithCacheStorage(cache, fallback)
    
    218 283
     
    
    219
    -    assert not with_cache_storage.has_blob(abc_digest)
    
    220
    -    write(with_cache_storage, abc_digest, abc)
    
    221
    -    assert cache.has_blob(abc_digest)
    
    222
    -    assert fallback.has_blob(abc_digest)
    
    223
    -    assert with_cache_storage.get_blob(abc_digest).read() == abc
    
    284
    +    assert not with_cache_storage.has_blob(digest1)
    
    285
    +    write(with_cache_storage, digest1, blob1)
    
    286
    +    assert cache.has_blob(digest1)
    
    287
    +    assert fallback.has_blob(digest1)
    
    288
    +    assert with_cache_storage.get_blob(digest1).read() == blob1
    
    224 289
     
    
    225 290
         # Even if a blob is in cache, we still need to check if the fallback
    
    226 291
         # has it.
    
    227
    -    write(cache, defg_digest, defg)
    
    228
    -    assert not with_cache_storage.has_blob(defg_digest)
    
    229
    -    write(fallback, defg_digest, defg)
    
    230
    -    assert with_cache_storage.has_blob(defg_digest)
    
    292
    +    write(cache, digest2, blob2)
    
    293
    +    assert not with_cache_storage.has_blob(digest2)
    
    294
    +    write(fallback, digest2, blob2)
    
    295
    +    assert with_cache_storage.has_blob(digest2)
    
    231 296
     
    
    232 297
         # When a blob is in the fallback but not the cache, reading it should
    
    233 298
         # put it into the cache.
    
    234
    -    write(fallback, hijk_digest, hijk)
    
    235
    -    assert with_cache_storage.get_blob(hijk_digest).read() == hijk
    
    236
    -    assert cache.has_blob(hijk_digest)
    
    237
    -    assert cache.get_blob(hijk_digest).read() == hijk
    
    238
    -    assert cache.has_blob(hijk_digest)
    299
    +    write(fallback, digest3, blob3)
    
    300
    +    assert with_cache_storage.get_blob(digest3).read() == blob3
    
    301
    +    assert cache.has_blob(digest3)
    
    302
    +    assert cache.get_blob(digest3).read() == blob3
    
    303
    +    assert cache.has_blob(digest3)
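
One note for reviewers of the storage tests above: with the new 'remote' parameter the any_storage fixture yields a bare 'host:port' string instead of a storage object, which is why every test branches on isinstance(any_storage, str) before deciding whether to run its body in a subprocess. The dispatch, reduced to its essentials (the helper name is illustrative, not part of the patch):

    import grpc

    from buildgrid.server.cas.storage.remote import RemoteStorage

    def run_against(any_storage, test_body, *args):
        if isinstance(any_storage, str):
            # 'remote' case: connect to the CAS server spawned by serve_cas()
            channel = grpc.insecure_channel(any_storage)
            test_body(RemoteStorage(channel, 'testing'), *args)
        else:
            # Every other case is already a ready-to-use storage object:
            test_body(any_storage, *args)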

  • tests/server_instance.py
    ... ... @@ -13,14 +13,16 @@
    13 13
     # limitations under the License.
    
    14 14
     
    
    15 15
     
    
    16
    +from buildgrid._app.settings import parser
    
    17
    +from buildgrid._app.commands.cmd_server import _create_server_from_config
    
    16 18
     from buildgrid.server.cas.service import ByteStreamService, ContentAddressableStorageService
    
    17 19
     from buildgrid.server.actioncache.service import ActionCacheService
    
    18 20
     from buildgrid.server.execution.service import ExecutionService
    
    19 21
     from buildgrid.server.operations.service import OperationsService
    
    20 22
     from buildgrid.server.bots.service import BotsService
    
    21 23
     from buildgrid.server.referencestorage.service import ReferenceStorageService
    
    22
    -from buildgrid._app.settings import parser
    
    23
    -from buildgrid._app.commands.cmd_server import _create_server_from_config
    
    24
    +
    
    25
    +from .utils.cas import run_in_subprocess
    
    24 26
     
    
    25 27
     
    
    26 28
     config = """
    
    ... ... @@ -69,17 +71,25 @@ instances:
    69 71
     
    
    70 72
     
    
    71 73
     def test_create_server():
    
    72
    -    settings = parser.get_parser().safe_load(config)
    
    73
    -
    
    74
    -    server = _create_server_from_config(settings)
    
    75
    -
    
    76
    -    server.start()
    
    77
    -    server.stop()
    
    78
    -
    
    79
    -    assert isinstance(server._execution_service, ExecutionService)
    
    80
    -    assert isinstance(server._operations_service, OperationsService)
    
    81
    -    assert isinstance(server._bots_service, BotsService)
    
    82
    -    assert isinstance(server._reference_storage_service, ReferenceStorageService)
    
    83
    -    assert isinstance(server._action_cache_service, ActionCacheService)
    
    84
    -    assert isinstance(server._cas_service, ContentAddressableStorageService)
    
    85
    -    assert isinstance(server._bytestream_service, ByteStreamService)
    74
    +    # Actual test function, to be run in a subprocess:
    
    75
    +    def __test_create_server(queue, config_data):
    
    76
    +        settings = parser.get_parser().safe_load(config_data)
    
    77
    +        server = _create_server_from_config(settings)
    
    78
    +
    
    79
    +        server.start()
    
    80
    +        server.stop()
    
    81
    +
    
    82
    +        try:
    
    83
    +            assert isinstance(server._execution_service, ExecutionService)
    
    84
    +            assert isinstance(server._operations_service, OperationsService)
    
    85
    +            assert isinstance(server._bots_service, BotsService)
    
    86
    +            assert isinstance(server._reference_storage_service, ReferenceStorageService)
    
    87
    +            assert isinstance(server._action_cache_service, ActionCacheService)
    
    88
    +            assert isinstance(server._cas_service, ContentAddressableStorageService)
    
    89
    +            assert isinstance(server._bytestream_service, ByteStreamService)
    
    90
    +        except AssertionError:
    
    91
    +            queue.put(False)
    
    92
    +        else:
    
    93
    +            queue.put(True)
    
    94
    +
    
    95
    +    assert run_in_subprocess(__test_create_server, config)

  • tests/utils/__init__.py

  • tests/utils/cas.py
    +# Copyright (C) 2018 Bloomberg LP
    +#
    +# Licensed under the Apache License, Version 2.0 (the "License");
    +# you may not use this file except in compliance with the License.
    +# You may obtain a copy of the License at
    +#
    +#  <http://www.apache.org/licenses/LICENSE-2.0>
    +#
    +# Unless required by applicable law or agreed to in writing, software
    +# distributed under the License is distributed on an "AS IS" BASIS,
    +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    +# See the License for the specific language governing permissions and
    +# limitations under the License.
    +
    +
    +from concurrent import futures
    +from contextlib import contextmanager
    +import multiprocessing
    +import os
    +import signal
    +import tempfile
    +
    +import grpc
    +import psutil
    +import pytest_cov
    +
    +from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2
    +from buildgrid.server.cas.service import ByteStreamService
    +from buildgrid.server.cas.service import ContentAddressableStorageService
    +from buildgrid.server.cas.instance import ByteStreamInstance
    +from buildgrid.server.cas.instance import ContentAddressableStorageInstance
    +from buildgrid.server.cas.storage.disk import DiskStorage
    +
    +
    +@contextmanager
    +def serve_cas(instances):
    +    server = Server(instances)
    +    try:
    +        yield server
    +    finally:
    +        server.quit()
    +
    +
    +def kill_process_tree(pid):
    +    proc = psutil.Process(pid)
    +    children = proc.children(recursive=True)
    +
    +    def kill_proc(p):
    +        try:
    +            p.kill()
    +        except psutil.AccessDenied:
    +            # Ignore this error, it can happen with
    +            # some setuid bwrap processes.
    +            pass
    +
    +    # Bloody Murder
    +    for child in children:
    +        kill_proc(child)
    +    kill_proc(proc)
    +
    +
    +def run_in_subprocess(function, *arguments):
    +    queue = multiprocessing.Queue()
    +    # Use subprocess to avoid creation of gRPC threads in main process
    +    # See https://github.com/grpc/grpc/blob/master/doc/fork_support.md
    +    process = multiprocessing.Process(target=function,
    +                                      args=(queue, *arguments))
    +
    +    try:
    +        process.start()
    +
    +        result = queue.get()
    +        process.join()
    +    except KeyboardInterrupt:
    +        kill_process_tree(process.pid)
    +        raise
    +
    +    return result
    +
    +
    +class Server:
    +
    +    def __init__(self, instances):
    +
    +        self.instances = instances
    +
    +        self.__storage_path = tempfile.TemporaryDirectory()
    +        self.__storage = DiskStorage(self.__storage_path.name)
    +
    +        self.__queue = multiprocessing.Queue()
    +        self.__process = multiprocessing.Process(
    +            target=Server.serve,
    +            args=(self.__queue, self.instances, self.__storage_path.name))
    +        self.__process.start()
    +
    +        self.port = self.__queue.get()
    +        self.remote = 'localhost:{}'.format(self.port)
    +
    +    @classmethod
    +    def serve(cls, queue, instances, storage_path):
    +        pytest_cov.embed.cleanup_on_sigterm()
    +
    +        # Use max_workers default from Python 3.5+
    +        max_workers = (os.cpu_count() or 1) * 5
    +        server = grpc.server(futures.ThreadPoolExecutor(max_workers))
    +        port = server.add_insecure_port('localhost:0')
    +
    +        storage = DiskStorage(storage_path)
    +
    +        bs_service = ByteStreamService(server)
    +        cas_service = ContentAddressableStorageService(server)
    +        for name in instances:
    +            bs_service.add_instance(name, ByteStreamInstance(storage))
    +            cas_service.add_instance(name, ContentAddressableStorageInstance(storage))
    +
    +        server.start()
    +        queue.put(port)
    +
    +        signal.pause()
    +
    +    def has(self, digest):
    +        return self.__storage.has_blob(digest)
    +
    +    def get(self, digest):
    +        return self.__storage.get_blob(digest).read()
    +
    +    def compare_blobs(self, digest, blob):
    +        if not self.__storage.has_blob(digest):
    +            return False
    +
    +        stored_blob = self.__storage.get_blob(digest)
    +        stored_blob = stored_blob.read()
    +
    +        return blob == stored_blob
    +
    +    def compare_messages(self, digest, message):
    +        if not self.__storage.has_blob(digest):
    +            return False
    +
    +        message_blob = message.SerializeToString()
    +
    +        stored_blob = self.__storage.get_blob(digest)
    +        stored_blob = stored_blob.read()
    +
    +        return message_blob == stored_blob
    +
    +    def compare_files(self, digest, file_path):
    +        if not self.__storage.has_blob(digest):
    +            return False
    +
    +        with open(file_path, 'rb') as file_bytes:
    +            file_blob = file_bytes.read()
    +
    +        stored_blob = self.__storage.get_blob(digest)
    +        stored_blob = stored_blob.read()
    +
    +        return file_blob == stored_blob
    +
    +    def compare_directories(self, digest, directory_path):
    +        if not self.__storage.has_blob(digest):
    +            return False
    +        elif not os.path.isdir(directory_path):
    +            return False
    +
    +        def __compare_folders(digest, path):
    +            directory = remote_execution_pb2.Directory()
    +            directory.ParseFromString(self.__storage.get_blob(digest).read())
    +
    +            files, directories, symlinks = [], [], []
    +            for entry in os.scandir(path):
    +                if entry.is_file(follow_symlinks=False):
    +                    files.append(entry.name)
    +
    +                elif entry.is_dir(follow_symlinks=False):
    +                    directories.append(entry.name)
    +
    +                elif os.path.islink(entry.path):
    +                    symlinks.append(entry.name)
    +
    +            assert len(files) == len(directory.files)
    +            assert len(directories) == len(directory.directories)
    +            assert len(symlinks) == len(directory.symlinks)
    +
    +            for file_node in directory.files:
    +                file_path = os.path.join(path, file_node.name)
    +
    +                assert file_node.name in files
    +                assert os.path.isfile(file_path)
    +                assert not os.path.islink(file_path)
    +                if file_node.is_executable:
    +                    assert os.access(file_path, os.X_OK)
    +
    +                assert self.compare_files(file_node.digest, file_path)
    +
    +            for directory_node in directory.directories:
    +                directory_path = os.path.join(path, directory_node.name)
    +
    +                assert directory_node.name in directories
    +                assert os.path.exists(directory_path)
    +                assert not os.path.islink(directory_path)
    +
    +                assert __compare_folders(directory_node.digest, directory_path)
    +
    +            for symlink_node in directory.symlinks:
    +                symlink_path = os.path.join(path, symlink_node.name)
    +
    +                assert symlink_node.name in symlinks
    +                assert os.path.islink(symlink_path)
    +                assert os.readlink(symlink_path) == symlink_node.target
    +
    +            return True
    +
    +        return __compare_folders(digest, directory_path)
    +
    +    def quit(self):
    +        if self.__process:
    +            self.__process.terminate()
    +            self.__process.join()
    +
    +        self.__storage_path.cleanup()
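Taken together, `serve_cas()` and `run_in_subprocess()` let a test spin up a throwaway CAS server backed by a temporary `DiskStorage` and exercise gRPC client code from a forked worker, following the same `queue.put(True/False)` pattern as the service test above. A minimal usage sketch (a hypothetical test, not part of this change; it assumes the fork start method, the default on Linux, so the nested target need not be picklable):

    import grpc

    from tests.utils.cas import run_in_subprocess, serve_cas


    def test_cas_server_reachable():
        # Spin up a CAS server exposing a single 'main' instance backed
        # by a temporary on-disk storage.
        with serve_cas(['main']) as server:

            def __test_connect(queue, remote):
                # The channel is created in the subprocess only, per the
                # gRPC fork-support guidance linked in run_in_subprocess():
                channel = grpc.insecure_channel(remote)
                try:
                    grpc.channel_ready_future(channel).result(timeout=5)
                except grpc.FutureTimeoutError:
                    queue.put(False)
                else:
                    queue.put(True)

            assert run_in_subprocess(__test_connect, server.remote)

Once the subprocess has written through the client, the main process can verify the results with the comparison helpers (`compare_blobs()`, `compare_messages()`, `compare_files()`, `compare_directories()`): both processes open a `DiskStorage` over the same temporary directory, so the blobs are shared through the filesystem rather than through process memory.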


