|
1
|
+# Copyright (C) 2018 Bloomberg LP
|
|
2
|
+#
|
|
3
|
+# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+# you may not use this file except in compliance with the License.
|
|
5
|
+# You may obtain a copy of the License at
|
|
6
|
+#
|
|
7
|
+# <http://www.apache.org/licenses/LICENSE-2.0>
|
|
8
|
+#
|
|
9
|
+# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+# See the License for the specific language governing permissions and
|
|
13
|
+# limitations under the License.
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+from datetime import datetime
|
|
17
|
+from enum import Enum
|
|
18
|
+
|
|
19
|
+import grpc
|
|
20
|
+import jwt
|
|
21
|
+
|
|
22
|
+from buildgrid._exceptions import InvalidArgumentError
|
|
23
|
+
|
|
24
|
+
|
|
25
|
class JwtAlgorithm(Enum):
    """JWT signing algorithms an interceptor may be configured to accept.

    Member names and values are identical JWS ``alg`` identifiers, grouped by
    algorithm family. Members are declared family by family; declaration order
    (and so iteration order) is unchanged.
    """

    # HMAC family (symmetric shared secret):
    HS256, HS384, HS512 = 'HS256', 'HS384', 'HS512'

    # RSASSA-PKCS1-v1_5 family (asymmetric, RSA keys):
    RS256, RS384, RS512 = 'RS256', 'RS384', 'RS512'

    # RSASSA-PSS family (asymmetric, RSA keys):
    PS256, PS384, PS512 = 'PS256', 'PS384', 'PS512'

    # ECDSA family (asymmetric, EC keys).
    # NOTE(review): 'ES521' is not a registered JWS identifier (the P-521
    # curve is 'ES512'); it is kept only so existing references keep working.
    ES256, ES384, ES521, ES512 = 'ES256', 'ES384', 'ES521', 'ES512'
|
|
46
|
+
|
|
47
|
+
|
|
48
|
class JwtAuthMetadataInterceptor(grpc.ServerInterceptor):
    """gRPC server interceptor enforcing JWT bearer-token authentication.

    Each incoming call must carry an ``authorization`` metadata entry of the
    form ``Bearer <token>``. The token is decoded and verified with the
    configured secret/key and algorithm, and must contain an integer ``exp``
    claim. Calls failing any check get a handler that aborts them with
    ``UNAUTHENTICATED``. Validated tokens are cached until their ``exp`` time
    so repeated calls with the same token skip re-verification.
    """

    __auth_errors = {
        'missing-bearer': 'Missing authentication header field.',
        'invalid-bearer': 'Invalid authentication header field.',
        'invalid-token': 'Invalid authentication token.',
        'expired-token': 'Expired authentication token.',
        'unbounded-token': 'Unbounded authentication token.',
    }

    # gRPC metadata keys are lower-case on the wire, so the header is matched
    # case-insensitively against this lower-case name.
    __auth_header = 'authorization'
    __bearer_prefix = 'Bearer '

    def __init__(self, secret, algorithm):
        """Initialises a new :class:`JwtAuthMetadataInterceptor`.

        Args:
            secret (str): Symmetric secret key (HMAC algorithms) or asymmetric
                public key used to verify token signatures.
            algorithm (JwtAlgorithm): Algorithm tokens must be signed with;
                tokens signed with any other algorithm are rejected.

        Raises:
            InvalidArgumentError: If `algorithm` is not registered with PyJWT.
        """
        self.__bearer_cache = {}
        self.__terminators = {}
        self.__secret = secret

        self._algorithm = algorithm.value

        # Support probe: PyJWT's register_algorithm() raises ValueError when
        # the id is already registered (i.e. supported) and TypeError when the
        # handler is not an Algorithm instance, so passing None never actually
        # registers anything.
        try:
            jwt.register_algorithm(self._algorithm, None)

        except TypeError:
            raise InvalidArgumentError('Algorithm not supported for JWT decoding: [{}]'
                                       .format(self._algorithm)) from None

        except ValueError:
            pass

        for code, message in self.__auth_errors.items():
            self.__terminators[code] = _unary_unary_rpc_terminator(message)

    @property
    def algorithm(self):
        """JwtAlgorithm: The algorithm incoming tokens must be signed with."""
        return JwtAlgorithm(self._algorithm)

    def __extract_bearer(self, handler_call_details):
        """Return the call's authorization metadata value, or None if absent."""
        bearer = None
        # Last matching entry wins, like building a dict from the metadata.
        for key, value in handler_call_details.invocation_metadata or ():
            if key.lower() == self.__auth_header:
                bearer = value
        return bearer

    def intercept_service(self, continuation, handler_call_details):
        """Authenticate a call before handing it to the next handler.

        Args:
            continuation: Callable returning the next handler for the call.
            handler_call_details: Details of the incoming call, including
                its invocation metadata.

        Returns:
            The handler returned by `continuation` when the call carries a
            valid, unexpired bearer token; otherwise a handler that aborts
            the call with ``UNAUTHENTICATED``.
        """
        bearer = self.__extract_bearer(handler_call_details)

        # Reject requests not carrying a token:
        if bearer is None:
            return self.__terminators['missing-bearer']  # Rejected

        # Reject requests with a malformed bearer:
        if not bearer.startswith(self.__bearer_prefix):
            return self.__terminators['invalid-bearer']  # Rejected

        # Hit the cache for an already validated token:
        expiration_time = self.__bearer_cache.get(bearer)
        if expiration_time is not None:
            # Accept request only while the cached token hasn't expired yet:
            if datetime.utcnow() < expiration_time:
                return continuation(handler_call_details)  # Accepted
            # Expired: evict it so the cache does not grow without bound; the
            # full check below then rejects it. pop() tolerates a concurrent
            # eviction of the same entry.
            self.__bearer_cache.pop(bearer, None)

        try:
            # Decode and validate the token, pinning the accepted algorithm so
            # a token cannot select a different one via its own header:
            payload = jwt.decode(bearer[len(self.__bearer_prefix):], self.__secret,
                                 algorithms=[self._algorithm])

        except jwt.exceptions.ExpiredSignatureError:
            return self.__terminators['expired-token']  # Rejected

        except jwt.exceptions.InvalidTokenError:
            return self.__terminators['invalid-token']  # Rejected

        # Do not accept token without an expiration time:
        if 'exp' not in payload or not isinstance(payload['exp'], int):
            return self.__terminators['unbounded-token']  # Rejected

        try:
            # Expiry is stored as naive UTC to match the utcnow() comparison:
            expiration_time = datetime.utcfromtimestamp(payload['exp'])

        except (OverflowError, OSError, ValueError):
            # An 'exp' beyond the representable range is effectively unbounded.
            return self.__terminators['unbounded-token']  # Rejected

        # Cache the validated token and its expiration time:
        self.__bearer_cache[bearer] = expiration_time

        return continuation(handler_call_details)  # Accepted
|
|
132
|
+
|
|
133
|
+
|
|
134
|
def _unary_unary_rpc_terminator(details):
    """Build a unary-unary gRPC method handler that rejects every call.

    Args:
        details (str): Error message attached to the aborted call.

    Returns:
        A method handler whose behaviour aborts each call with
        ``grpc.StatusCode.UNAUTHENTICATED`` and `details`.
    """
    def _abort_call(_request, servicer_context):
        # The request payload is irrelevant: every call is refused.
        servicer_context.abort(grpc.StatusCode.UNAUTHENTICATED, details)

    handler = grpc.unary_unary_rpc_method_handler(_abort_call)
    return handler
|