Source code for handtruck.credentials

  1import configparser
  2import contextlib
  3import datetime
  4import logging
  5import math
  6import os
  7from abc import ABC, abstractmethod
  8from dataclasses import dataclass
  9from functools import cached_property
 10from pathlib import Path
 11from typing import Any, List, Mapping, Optional, Tuple, TypedDict, Union
 12
 13import anyio
 14from aws_request_signer import AwsRequestSigner
 15from httpx import URL, AsyncClient
 16
 17
 18log = logging.getLogger(__name__)
 19
 20
class AbstractCredentials(ABC):
    """
    Interface shared by every credentials source.

    Implementations report, through their truthiness, whether usable
    credentials are currently held, and expose an :class:`AwsRequestSigner`
    through the :attr:`signer` property.
    """

    @abstractmethod
    def __bool__(self) -> bool:
        """Return whether usable credentials are available."""

    @property
    @abstractmethod
    def signer(self) -> AwsRequestSigner:
        """Request signer built from these credentials."""
30 31
@dataclass(frozen=True)
class StaticCredentials(AbstractCredentials):
    """
    Fixed credentials handed in by the caller.

    Instances are immutable; the request signer is built on first access to
    :attr:`signer` and reused afterwards.
    """
    access_key_id: str = ""
    secret_access_key: str = ""
    session_token: Optional[str] = None
    region: str = ""
    service: str = "s3"

    def __bool__(self) -> bool:
        # Only the key pair is mandatory; token and region are optional.
        return bool(self.access_key_id) and bool(self.secret_access_key)

    def __repr__(self) -> str:
        # The secret is masked and the session token left out entirely so
        # that printing or logging an instance does not leak them.
        masked_secret = "******" if self.secret_access_key else None
        return (
            f"{type(self).__name__}("
            f"access_key_id={self.access_key_id!r}, "
            f"secret_access_key={masked_secret!r}, "
            f"region={self.region!r}, service={self.service!r})"
        )

    def as_dict(self) -> dict:
        """Return the fields as keyword arguments for :class:`AwsRequestSigner`."""
        return dict(
            region=self.region,
            access_key_id=self.access_key_id,
            secret_access_key=self.secret_access_key,
            session_token=self.session_token,
            service=self.service,
        )

    @cached_property
    def signer(self) -> AwsRequestSigner:
        """Signer created from :meth:`as_dict` on first use, then cached."""
        return AwsRequestSigner(**self.as_dict())
66 67
class URLCredentials(StaticCredentials):
    """
    Credentials taken from the user-info part of a URL.

    For ``https://KEY:SECRET@host/`` the access key is ``KEY`` and the
    secret is ``SECRET``; a missing part becomes an empty string.
    """
    def __init__(
        self, url: Union[str, URL], *, region: str = "", service: str = "s3",
    ):
        parsed = URL(url)
        key_id = parsed.username or ""
        secret = parsed.password or ""
        super().__init__(
            access_key_id=key_id,
            secret_access_key=secret,
            region=region,
            service=service,
        )
81 82
class EnvironmentCredentials(StaticCredentials):
    """
    Credentials read from the standard AWS environment variables.

    ``$AWS_ACCESS_KEY_ID``, ``$AWS_SECRET_ACCESS_KEY`` and
    ``$AWS_SESSION_TOKEN`` supply the key material. ``$AWS_DEFAULT_REGION``,
    when set, takes precedence over the *region* argument.
    """
    def __init__(self, region: str = "", service: str = "s3"):
        env = os.environ
        super().__init__(
            access_key_id=env.get("AWS_ACCESS_KEY_ID", ""),
            secret_access_key=env.get("AWS_SECRET_ACCESS_KEY", ""),
            session_token=env.get("AWS_SESSION_TOKEN"),
            region=env.get("AWS_DEFAULT_REGION", region),
            service=service,
        )
95 96
class ConfigCredentials(StaticCredentials):
    """
    Credentials from the awscli config

    By default, these come from ``~/.aws/credentials`` and ``~/.aws/config``,
    but this is customizable with ``$AWS_SHARED_CREDENTIALS_FILE``,
    ``$AWS_SHARED_CONFIG_FILE`` (or awscli's ``$AWS_CONFIG_FILE``), and
    ``$AWS_PROFILE``.

    The key pair and the optional ``aws_session_token`` are read from the
    profile's section of the credentials file. The region is read from the
    config file, where named profiles live in ``[profile NAME]`` sections
    (a bare ``[NAME]`` section is accepted as a fallback). The two files are
    read independently, so either one may be missing. The *region* argument
    is used when the config file provides none.
    """
    DEFAULT_CREDENTIALS_PATH = Path.home() / ".aws" / "credentials"
    DEFAULT_CONFIG_PATH = Path.home() / ".aws" / "config"

    @staticmethod
    def _parse_ini_section(path: Path, section: str) -> Mapping[str, str]:
        """
        Return the options of ``[section]`` in the INI file at *path*.

        An empty mapping is returned when the file cannot be opened, cannot
        be parsed, or has no such section.
        """
        # Interpolation is disabled so values containing "%" are returned
        # verbatim instead of raising InterpolationSyntaxError.
        conf = configparser.ConfigParser(interpolation=None)
        try:
            # read() silently skips files it cannot open and returns the
            # list of files it actually parsed.
            if not conf.read(path):
                return {}
        except (configparser.Error, UnicodeDecodeError):
            # The exception text is deliberately not logged: parse errors
            # quote the offending lines, which may contain secrets.
            log.warning("Failed to parse AWS config file %s", path)
            return {}

        if section not in conf:
            return {}

        return conf[section]

    def __init__(
        self,
        credentials_path: Union[str, Path, None] = None,
        config_path: Union[str, Path, None] = None, *,
        region: str = "", service: str = "s3", profile: str = "auto",
    ):
        """
        :param credentials_path: credentials file; ``None`` means
            ``$AWS_SHARED_CREDENTIALS_FILE`` or
            :attr:`DEFAULT_CREDENTIALS_PATH`.
        :param config_path: config file; ``None`` means
            ``$AWS_SHARED_CONFIG_FILE``, then ``$AWS_CONFIG_FILE``, then
            :attr:`DEFAULT_CONFIG_PATH`.
        :param region: region to use when the config file has none.
        :param service: service name passed on to the request signer.
        :param profile: profile to read; ``"auto"`` means ``$AWS_PROFILE``
            or ``"default"``.
        """
        if credentials_path is None:
            credentials_path = (
                os.getenv("AWS_SHARED_CREDENTIALS_FILE")
                or self.DEFAULT_CREDENTIALS_PATH
            )
        credentials_path = Path(credentials_path).expanduser()

        if config_path is None:
            config_path = (
                os.getenv("AWS_SHARED_CONFIG_FILE")
                or os.getenv("AWS_CONFIG_FILE")
                or self.DEFAULT_CONFIG_PATH
            )
        config_path = Path(config_path).expanduser()

        if profile == "auto":
            profile = os.getenv("AWS_PROFILE", "default")

        # Each file is read on its own; a missing file just yields {}.
        credentials = self._parse_ini_section(credentials_path, profile)

        # awscli stores named profiles as "[profile NAME]" in the config
        # file; the bare "[NAME]" form (e.g. "[default]") is the fallback.
        config = self._parse_ini_section(config_path, f"profile {profile}")
        if not config:
            config = self._parse_ini_section(config_path, profile)

        super().__init__(
            access_key_id=credentials.get("aws_access_key_id", ""),
            secret_access_key=credentials.get("aws_secret_access_key", ""),
            session_token=credentials.get("aws_session_token") or None,
            region=config.get("region") or region,
            service=service,
        )
170 171 172ENVIRONMENT_CREDENTIALS = EnvironmentCredentials() 173 174
def merge_credentials(*credentials: StaticCredentials) -> StaticCredentials:
    """
    Combine several credentials into one :class:`StaticCredentials`.

    Each field takes the first truthy value found among *credentials*, in
    the order given. Fields that no source provides keep the
    :class:`StaticCredentials` defaults.
    """
    merged = {}
    for name in (
        "access_key_id", "secret_access_key",
        "session_token", "region", "service",
    ):
        for source in credentials:
            value = getattr(source, name, None)
            if value:
                merged[name] = value
                break
    return StaticCredentials(**merged)
195 196
def collect_credentials(
    *, url: Optional[URL] = None, **kwargs,
) -> StaticCredentials:
    """
    Gather credentials from the standard sources and combine them with
    :func:`merge_credentials()`.

    Sources, from highest to lowest priority:

    1. the keyword arguments, as a :class:`StaticCredentials`;
    2. *url*, as :class:`URLCredentials`;
    3. the environment, as :class:`EnvironmentCredentials`;
    4. the awscli files, as :class:`ConfigCredentials`.

    If you want to emulate awscli, use this.
    """
    explicit = [StaticCredentials(**kwargs)] if kwargs else []
    from_url = [URLCredentials(url)] if url else []
    return merge_credentials(
        *explicit,
        *from_url,
        EnvironmentCredentials(),
        ConfigCredentials(),
    )
213 214 215class MetadataDocument(TypedDict, total=False): 216 """ 217 Response example is: 218 219 { 220 "accountId" : "123123", 221 "architecture" : "x86_64", 222 "availabilityZone" : "us-east-1a", 223 "billingProducts" : null, 224 "devpayProductCodes" : null, 225 "marketplaceProductCodes" : null, 226 "imageId" : "ami-123123", 227 "instanceId" : "i-11232323", 228 "instanceType" : "t3a.micro", 229 "kernelId" : null, 230 "pendingTime" : "2023-06-13T18:18:58Z", 231 "privateIp" : "172.33.33.33", 232 "ramdiskId" : null, 233 "region" : "us-east-1", 234 "version" : "2017-09-30" 235 } 236 237 :meta private: 238 """ 239 region: str 240 241 242class MetadataSecurityCredentials(TypedDict, total=False): 243 """ 244 :meta private: 245 """ 246 Code: str 247 Type: str 248 AccessKeyId: str 249 SecretAccessKey: str 250 Token: str 251 Expiration: str 252 253
class MetadataCredentials(AbstractCredentials, anyio.AsyncContextManagerMixin):
    """
    Credentials from the AWS metadata service.

    This only works inside of AWS, and unlike everything else, is dynamic and
    kept fresh. The async context manager manages this process::

        async with MetadataCredentials() as creds:
            client = S3Client(..., credentials=creds)

    Entering the context waits until the first set of credentials has been
    fetched successfully. On failure, the fetch is retried every
    :attr:`REFRESH_RETRY_INTERVAL` seconds. Afterwards the credentials are
    refreshed in the background at half of their remaining lifetime.
    Leaving the context stops the refresher and closes the HTTP client;
    entering again creates a new client.

    Requests first try to obtain an IMDSv2 session token and fall back to
    token-less (IMDSv1) requests if that fails.
    """
    METADATA_ADDRESS: str = "169.254.169.254"
    METADATA_PORT: int = 80
    #: Seconds to wait before retrying after a failed refresh.
    REFRESH_RETRY_INTERVAL: int = 60
    #: Minimum seconds between successful refreshes, so that credentials
    #: which are already (nearly) expired cannot cause a busy loop.
    REFRESH_MIN_INTERVAL: int = 5
    #: Lifetime, in seconds, requested for IMDSv2 session tokens.
    IMDS_TOKEN_TTL: int = 21600

    def __init__(self, *, service: str = "s3"):
        self.session = self._make_session()
        self.service = service
        #: :meta private:
        self.refresh_lock: anyio.Lock = anyio.Lock()
        self._signer: Optional[AwsRequestSigner] = None

    def _make_session(self) -> AsyncClient:
        """Create an HTTP client rooted at the metadata service address."""
        return AsyncClient(
            base_url=URL(
                scheme="http",
                host=self.METADATA_ADDRESS.rstrip("/"),
                port=self.METADATA_PORT,
            ),
        )

    @contextlib.asynccontextmanager
    async def __asynccontextmanager__(self):
        # An httpx client cannot be reopened once closed, so re-entering
        # after a previous exit needs a fresh one.
        if self.session.is_closed:
            self.session = self._make_session()
        try:
            async with anyio.create_task_group() as taskgroup:
                # start() returns once the first refresh has succeeded.
                await taskgroup.start(
                    self._refresher, name="MetadataCredentials-refresher",
                )
                yield self
                taskgroup.cancel_scope.cancel()
        finally:
            # Shielded so the client is closed even when the context is
            # left because of cancellation.
            with anyio.CancelScope(shield=True):
                await self.session.aclose()

    def __bool__(self) -> bool:
        """True once credentials have been fetched at least once."""
        return self._signer is not None

    async def _refresher(
        self,
        *,
        task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED,
    ) -> None:
        """
        Keep :attr:`signer` fresh until cancelled.

        After a successful fetch, the next one is scheduled at half of the
        remaining credential lifetime, but never sooner than
        :attr:`REFRESH_MIN_INTERVAL`. After a failure, it is retried in
        :attr:`REFRESH_RETRY_INTERVAL` seconds. *task_status* is marked as
        started after the first successful fetch only.
        """
        started = False
        while True:
            async with self.refresh_lock:
                try:
                    credentials, expires_at = await self._fetch_credentials()
                    self._signer = AwsRequestSigner(**credentials.as_dict())
                    now = datetime.datetime.now(datetime.timezone.utc)
                    sleep_time = max(
                        math.floor((expires_at - now).total_seconds() / 2),
                        self.REFRESH_MIN_INTERVAL,
                    )
                except Exception as ex:
                    log.exception("Failed to update credentials", exc_info=ex)
                    sleep_time = self.REFRESH_RETRY_INTERVAL
                else:
                    # TaskStatus.started() may be called only once; calling
                    # it again raises RuntimeError.
                    if not started:
                        task_status.started()
                        started = True
            await anyio.sleep(sleep_time)

    async def _imds_headers(self) -> Mapping[str, str]:
        """
        Return the headers that authenticate metadata requests.

        An IMDSv2 session token is requested first. If that fails for any
        reason, an empty mapping is returned and requests go out without a
        token (IMDSv1).
        """
        try:
            response = await self.session.put(
                "/latest/api/token",
                headers={
                    "X-aws-ec2-metadata-token-ttl-seconds":
                        str(self.IMDS_TOKEN_TTL),
                },
            )
            response.raise_for_status()
        except Exception:
            log.debug(
                "IMDSv2 token request failed, falling back to IMDSv1",
                exc_info=True,
            )
            return {}
        return {"X-aws-ec2-metadata-token": response.text.strip()}

    @staticmethod
    def _parse_expiration(value: str) -> datetime.datetime:
        """Parse an ISO 8601 expiry timestamp into an aware datetime."""
        # Before Python 3.11, fromisoformat() rejects a trailing "Z".
        if value.endswith("Z"):
            value = value[:-1] + "+00:00"
        expires_at = datetime.datetime.fromisoformat(value)
        if expires_at.tzinfo is None:
            # NOTE(review): naive timestamps are assumed to be UTC.
            expires_at = expires_at.replace(tzinfo=datetime.timezone.utc)
        return expires_at

    async def _fetch_credentials(
        self,
    ) -> Tuple[StaticCredentials, datetime.datetime]:
        """
        Fetch the region and role credentials from the metadata service.

        :returns: the credentials and their expiry time.
        :raises Exception: on HTTP errors, missing IAM role, or malformed
            responses; :meth:`_refresher` logs these and retries.
        """
        headers = await self._imds_headers()

        response = await self.session.get(
            "/latest/dynamic/instance-identity/document", headers=headers,
        )
        response.raise_for_status()
        document: MetadataDocument = response.json()

        response = await self.session.get(
            "/latest/meta-data/iam/security-credentials/", headers=headers,
        )
        response.raise_for_status()
        # The listing holds one role name per line; surrounding whitespace
        # must not end up in the next URL.
        roles = response.text.split()
        if not roles:
            raise RuntimeError("No IAM role is attached to this instance")
        iam_role = roles[0]

        response = await self.session.get(
            f"/latest/meta-data/iam/security-credentials/{iam_role}",
            headers=headers,
        )
        response.raise_for_status()
        credentials: MetadataSecurityCredentials = response.json()

        return (
            StaticCredentials(
                region=document["region"],
                access_key_id=credentials["AccessKeyId"],
                secret_access_key=credentials["SecretAccessKey"],
                session_token=credentials["Token"],
                service=self.service,
            ),
            self._parse_expiration(credentials["Expiration"]),
        )

    @property
    def signer(self) -> AwsRequestSigner:
        if not self._signer:
            raise RuntimeError(
                f"{self.__class__.__name__} must be started before using",
            )
        return self._signer
 # Public API of this module: the names exported by ``from ... import *``.
__all__ = (
    "AbstractCredentials",
    "ConfigCredentials",
    "EnvironmentCredentials",
    "MetadataCredentials",
    "StaticCredentials",
    "URLCredentials",
    "collect_credentials",
    "merge_credentials",
)