[Notes] [Git][BuildStream/buildstream][juerg/cas-batch] _artifactcache/cascache.py: Use BatchReadBlobs



Title: GitLab

Jürg Billeter pushed to branch juerg/cas-batch at BuildStream / buildstream

Commits:

1 changed file:

Changes:

  • buildstream/_artifactcache/cascache.py
    ... ... @@ -884,6 +884,56 @@ class CASCache(ArtifactCache):
    884 884
     
    
    885 885
             return objpath
    
    886 886
     
    
    887
    # _batch_download_complete():
    #
    # Sends a pending batch request and stores every returned blob in the
    # local CAS.
    #
    # Args:
    #     batch (_CASBatchRead): The pending batch; it is sent by this call.
    #
    # Raises:
    #     ArtifactError: If a returned blob does not match the digest it was
    #                    requested under.
    #
    def _batch_download_complete(self, batch):
        for digest, data in batch.send():
            # Stage the blob in a temporary file so it can be imported through
            # add_object(), which reports the digest of what was stored.
            with tempfile.NamedTemporaryFile(dir=self.tmpdir) as f:
                f.write(data)
                f.flush()

                added_digest = self.add_object(path=f.name)

                # The data comes from the remote, so verify it explicitly:
                # an `assert` would be stripped when running under `python -O`.
                if added_digest.hash != digest.hash:
                    raise ArtifactError("Failed to download blob {}: received blob has digest {}".format(
                        digest.hash, added_digest.hash))
    
    895
    +
    
    896
    # _fetch_directory_batch():
    #
    # Helper function for _fetch_directory().
    #
    # Completes the pending batch, then promotes every directory that was
    # waiting on that batch into the processing queue.
    #
    # Args:
    #     remote: The remote the batch was built for.
    #     batch (_CASBatchRead): The pending batch to complete.
    #     fetch_queue: Directories ready to be processed (extended in place).
    #     fetch_next_queue: Directories waiting on the batch (emptied in place).
    #
    # Returns:
    #     (_CASBatchRead): A new, empty batch for the same remote.
    #
    def _fetch_directory_batch(self, remote, batch, fetch_queue, fetch_next_queue):
        self._batch_download_complete(batch)

        # The deferred directories were part of the batch that just completed,
        # so they are now in the local cache and ready to be parsed.
        promoted = list(fetch_next_queue)
        fetch_next_queue.clear()
        fetch_queue.extend(promoted)

        return _CASBatchRead(remote)
    
    905
    +
    
    906
    # _fetch_directory_node():
    #
    # Helper function for _fetch_directory().
    #
    # Makes sure one blob referenced by a directory is, or will become,
    # available locally:
    # - blobs already in the local cache are skipped;
    # - blobs that cannot go into a batch request are downloaded right away;
    # - all other blobs are added to the pending batch.
    # If the pending batch has no room left, it is completed first and a new
    # batch is started.
    #
    # Args:
    #     remote: The remote to fetch from.
    #     digest (Digest): The blob to fetch.
    #     batch (_CASBatchRead): The pending batch.
    #     fetch_queue: Directories available locally that still need processing.
    #     fetch_next_queue: Directories that become available once the pending
    #                       batch completes.
    #     recursive (bool): Whether the blob is a directory that must itself be
    #                       queued for processing.
    #
    # Returns:
    #     (_CASBatchRead): The batch to use for subsequent blobs.
    #
    def _fetch_directory_node(self, remote, digest, batch, fetch_queue, fetch_next_queue, *, recursive=False):
        available_now = os.path.exists(self.objpath(digest))

        if available_now:
            # Skip download, already in local cache.
            pass
        elif (digest.size_bytes >= remote.max_batch_total_size_bytes or
                not remote.batch_read_supported):
            # Too large for batch request, download in independent request.
            self._ensure_blob(remote, digest)
            available_now = True
        elif not batch.add(digest):
            # Not enough space left in batch request.
            # Complete pending batch first.
            batch = self._fetch_directory_batch(remote, batch, fetch_queue, fetch_next_queue)
            batch.add(digest)

        if recursive:
            # A directory already on disk can be processed immediately; one that
            # is part of the pending batch has to wait until it completes.
            target_queue = fetch_queue if available_now else fetch_next_queue
            target_queue.append(digest)

        return batch
    
    936
    +
    
    887 937
         # _fetch_directory():
    
    888 938
         #
    
    889 939
         # Fetches remote directory and adds it to content addressable store.
    
    ... ... @@ -897,23 +947,32 @@ class CASCache(ArtifactCache):
    897 947
         #     dir_digest (Digest): Digest object for the directory to fetch.
    
    898 948
         #
    
    def _fetch_directory(self, remote, dir_digest):
        # popleft() on a deque is O(1); list.pop(0) would make the
        # breadth-first walk quadratic in the number of queued directories.
        from collections import deque

        # Directories whose blob is available locally and whose entries
        # still need to be processed.
        fetch_queue = deque([dir_digest])
        # Directories that are part of the pending batch; they are moved to
        # fetch_queue once that batch has been completed.
        fetch_next_queue = []
        batch = _CASBatchRead(remote)

        while fetch_queue or fetch_next_queue:
            if not fetch_queue:
                # Everything left is waiting on the pending batch.
                batch = self._fetch_directory_batch(remote, batch, fetch_queue, fetch_next_queue)

            digest = fetch_queue.popleft()

            # Downloads the blob if it is not present yet (e.g. the root).
            objpath = self._ensure_blob(remote, digest)

            directory = remote_execution_pb2.Directory()
            with open(objpath, 'rb') as f:
                directory.ParseFromString(f.read())

            for dirnode in directory.directories:
                batch = self._fetch_directory_node(remote, dirnode.digest, batch,
                                                   fetch_queue, fetch_next_queue, recursive=True)

            for filenode in directory.files:
                batch = self._fetch_directory_node(remote, filenode.digest, batch,
                                                   fetch_queue, fetch_next_queue)

        # Fetch final batch
        self._fetch_directory_batch(remote, batch, fetch_queue, fetch_next_queue)
    
    917 976
     
    
    918 977
         def _fetch_tree(self, remote, digest):
    
    919 978
             # download but do not store the Tree object
    
    ... ... @@ -1040,11 +1099,78 @@ class _CASRemote():
    1040 1099
     
    
    1041 1100
                 self.bytestream = bytestream_pb2_grpc.ByteStreamStub(self.channel)
    
    1042 1101
                 self.cas = remote_execution_pb2_grpc.ContentAddressableStorageStub(self.channel)
    
    1102
    +            self.capabilities = remote_execution_pb2_grpc.CapabilitiesStub(self.channel)
    
    1043 1103
                 self.ref_storage = buildstream_pb2_grpc.ReferenceStorageStub(self.channel)
    
    1044 1104
     
    
    1105
    +            self.max_batch_total_size_bytes = _MAX_PAYLOAD_BYTES
    
    1106
    +            try:
    
    1107
    +                request = remote_execution_pb2.GetCapabilitiesRequest()
    
    1108
    +                response = self.capabilities.GetCapabilities(request)
    
    1109
    +                server_max_batch_total_size_bytes = response.cache_capabilities.max_batch_total_size_bytes
    
    1110
    +                if 0 < server_max_batch_total_size_bytes < self.max_batch_total_size_bytes:
    
    1111
    +                    self.max_batch_total_size_bytes = server_max_batch_total_size_bytes
    
    1112
    +            except grpc.RpcError as e:
    
    1113
    +                # Simply use the defaults for servers that don't implement GetCapabilities()
    
    1114
    +                if e.code() != grpc.StatusCode.UNIMPLEMENTED:
    
    1115
    +                    raise
    
    1116
    +
    
    1117
    +            # Check whether the server supports BatchReadBlobs()
    
    1118
    +            self.batch_read_supported = False
    
    1119
    +            try:
    
    1120
    +                request = remote_execution_pb2.BatchReadBlobsRequest()
    
    1121
    +                response = self.cas.BatchReadBlobs(request)
    
    1122
    +                self.batch_read_supported = True
    
    1123
    +            except grpc.RpcError as e:
    
    1124
    +                if e.code() != grpc.StatusCode.UNIMPLEMENTED:
    
    1125
    +                    raise
    
    1126
    +
    
    1045 1127
                 self._initialized = True
    
    1046 1128
     
    
    1047 1129
     
    
    1130
    # Represents a batch of blobs queued for fetching.
    #
    # Blobs are queued with add() until the remote's maximum batch size would
    # be exceeded. Iterating over send() then downloads all of them with a
    # single BatchReadBlobs() request.
    #
    class _CASBatchRead():
        def __init__(self, remote):
            self._remote = remote
            self._max_total_size_bytes = remote.max_batch_total_size_bytes
            self._request = remote_execution_pb2.BatchReadBlobsRequest()
            self._size = 0
            self._sent = False
            # Hashes already queued in this batch. Identical blobs (e.g. equal
            # or empty files in a tree) are requested and counted only once.
            self._hashes = set()

        # add():
        #
        # Queues a blob for download.
        #
        # Args:
        #     digest (Digest): The blob to fetch.
        #
        # Returns:
        #     (bool): True if the blob is part of the batch, including when it
        #             was already queued; False if it does not fit in the
        #             remaining space.
        #
        def add(self, digest):
            assert not self._sent

            if digest.hash in self._hashes:
                # Already queued, the server will return it once.
                return True

            new_batch_size = self._size + digest.size_bytes
            if new_batch_size > self._max_total_size_bytes:
                # Not enough space left in current batch
                return False

            request_digest = self._request.digests.add()
            request_digest.hash = digest.hash
            request_digest.size_bytes = digest.size_bytes
            self._size = new_batch_size
            self._hashes.add(digest.hash)
            return True

        # send():
        #
        # Sends the batch request and yields the downloaded blobs. A batch can
        # only be sent once; an empty batch sends nothing.
        #
        # Yields:
        #     (Digest, bytes): The digest and content of each returned blob.
        #
        # Raises:
        #     ArtifactError: If a blob could not be downloaded, has an
        #                    unexpected size, or the server did not return
        #                    every requested blob.
        #
        def send(self):
            assert not self._sent
            self._sent = True

            if len(self._request.digests) == 0:
                return

            batch_response = self._remote.cas.BatchReadBlobs(self._request)

            # Track requested blobs so that an incomplete response is reported
            # instead of silently leaving blobs missing from the local cache.
            missing = set(self._hashes)

            for response in batch_response.responses:
                if response.status.code != grpc.StatusCode.OK.value[0]:
                    raise ArtifactError("Failed to download blob {}: {}".format(
                        response.digest.hash, response.status.code))
                if response.digest.size_bytes != len(response.data):
                    raise ArtifactError("Failed to download blob {}: expected {} bytes, received {} bytes".format(
                        response.digest.hash, response.digest.size_bytes, len(response.data)))

                missing.discard(response.digest.hash)
                yield (response.digest, response.data)

            if missing:
                raise ArtifactError("Failed to download {} blob(s): no response from server, e.g. for blob {}".format(
                    len(missing), min(missing)))
    
    1172
    +
    
    1173
    +
    
    1048 1174
     def _grouper(iterable, n):
    
    1049 1175
         while True:
    
    1050 1176
             try:
    



  • [Date Prev][Date Next]   [Thread Prev][Thread Next]   [Thread Index] [Date Index] [Author Index]