Source code for handtruck.credentials

import configparser
import contextlib
import datetime
import logging
import math
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
from typing import Any, List, Mapping, Optional, Tuple, TypedDict, Union

import anyio
from aws_request_signer import AwsRequestSigner
from httpx import URL, AsyncClient, HTTPError


 18log = logging.getLogger(__name__)
 19
 20
class AbstractCredentials(ABC):
    """Interface shared by every credentials provider in this module."""

    @abstractmethod
    def __bool__(self) -> bool:
        """Report whether this provider currently holds credentials."""

    @property
    @abstractmethod
    def signer(self) -> AwsRequestSigner:
        """Request signer built from this provider's credentials."""


@dataclass(frozen=True)
class StaticCredentials(AbstractCredentials):
    """Immutable, in-memory credentials.

    An instance is truthy only when both ``access_key_id`` and
    ``secret_access_key`` are non-empty. ``repr()`` masks the secret and
    omits the session token.
    """

    access_key_id: str = ""
    secret_access_key: str = ""
    # Optional temporary-credentials token; forwarded to the signer.
    session_token: Optional[str] = None
    region: str = ""
    # Service name forwarded to the signer.
    service: str = "s3"

    def __bool__(self) -> bool:
        return bool(self.access_key_id) and bool(self.secret_access_key)

    def __repr__(self) -> str:
        masked_secret = "******" if self.secret_access_key else None
        parts = (
            f"access_key_id={self.access_key_id!r}",
            f"secret_access_key={masked_secret!r}",
            f"region={self.region!r}",
            f"service={self.service!r}",
        )
        return f"{type(self).__name__}({', '.join(parts)})"

    def as_dict(self) -> dict:
        """Return the fields as keyword arguments for ``AwsRequestSigner``."""
        return dict(
            region=self.region,
            access_key_id=self.access_key_id,
            secret_access_key=self.secret_access_key,
            session_token=self.session_token,
            service=self.service,
        )

    @cached_property
    def signer(self) -> AwsRequestSigner:
        """Signer built from ``as_dict()``; created once per instance."""
        signer_kwargs = self.as_dict()
        return AwsRequestSigner(**signer_kwargs)


class URLCredentials(StaticCredentials):
    """Credentials taken from the user-info part of a URL.

    The URL's username becomes the access key id and its password the
    secret access key; a missing part becomes an empty string.
    """

    def __init__(
        self, url: Union[str, URL], *, region: str = "", service: str = "s3",
    ):
        parsed = URL(url)
        key_id = parsed.username or ""
        secret = parsed.password or ""
        super().__init__(
            access_key_id=key_id,
            secret_access_key=secret,
            region=region,
            service=service,
        )


class EnvironmentCredentials(StaticCredentials):
    """Credentials read from the ``AWS_*`` environment variables.

    The variables are read when the instance is created. When
    ``AWS_DEFAULT_REGION`` is set it takes precedence over ``region``.
    """

    def __init__(self, region: str = "", service: str = "s3"):
        env = os.environ
        super().__init__(
            access_key_id=env.get("AWS_ACCESS_KEY_ID", ""),
            secret_access_key=env.get("AWS_SECRET_ACCESS_KEY", ""),
            session_token=env.get("AWS_SESSION_TOKEN"),
            region=env.get("AWS_DEFAULT_REGION", region),
            service=service,
        )


class ConfigCredentials(StaticCredentials):
    """Credentials loaded from the AWS shared credentials and config files.

    ``aws_access_key_id`` and ``aws_secret_access_key`` come from the
    profile's section of the credentials file, and ``region`` from the
    profile's section of the config file. Each file is read independently.
    A missing, unreadable or malformed file contributes nothing, so the
    resulting instance may be falsy.

    :param credentials_path: credentials file. ``None`` means
        ``$AWS_SHARED_CREDENTIALS_FILE``, else ``~/.aws/credentials``.
    :param config_path: config file. ``None`` means
        ``$AWS_SHARED_CONFIG_FILE``, else ``$AWS_CONFIG_FILE``, else
        ``~/.aws/config``.
    :param region: region used when the config file does not provide one.
    :param service: service name forwarded to the request signer.
    :param profile: profile name. ``"auto"`` means ``$AWS_PROFILE``, else
        ``"default"``.
    """

    DEFAULT_CREDENTIALS_PATH = Path.home() / ".aws" / "credentials"
    DEFAULT_CONFIG_PATH = Path.home() / ".aws" / "config"

    @staticmethod
    def _parse_ini_section(path: Path, section: str) -> Mapping[str, str]:
        """Return the options of ``section`` in the INI file at ``path``.

        Returns an empty mapping when the file is missing, cannot be read
        or parsed, or does not contain the section.
        """
        # Interpolation is disabled: these files hold plain values, and a
        # literal "%" must not be treated as an interpolation marker.
        conf = configparser.ConfigParser(interpolation=None)
        try:
            if not conf.read(path, encoding="utf-8"):
                return {}
        except (OSError, UnicodeError, configparser.Error) as exc:
            # Only the exception type is logged: parser messages can quote
            # file lines, which may contain secrets.
            log.warning(
                "Ignoring unreadable AWS configuration file %s (%s)",
                path, type(exc).__name__,
            )
            return {}

        if section not in conf:
            return {}
        return dict(conf[section])

    def __init__(
        self,
        credentials_path: Union[str, Path, None] = None,
        config_path: Union[str, Path, None] = None, *,
        region: str = "", service: str = "s3", profile: str = "auto",
    ):
        if credentials_path is None:
            credentials_path = (
                os.getenv("AWS_SHARED_CREDENTIALS_FILE")
                or self.DEFAULT_CREDENTIALS_PATH
            )
        if config_path is None:
            # AWS_CONFIG_FILE is the variable documented by AWS;
            # AWS_SHARED_CONFIG_FILE is still honoured first for
            # backward compatibility.
            config_path = (
                os.getenv("AWS_SHARED_CONFIG_FILE")
                or os.getenv("AWS_CONFIG_FILE")
                or self.DEFAULT_CONFIG_PATH
            )
        credentials_path = Path(credentials_path).expanduser()
        config_path = Path(config_path).expanduser()

        if profile == "auto":
            profile = os.getenv("AWS_PROFILE") or "default"

        credentials = self._parse_ini_section(credentials_path, profile)
        # The config file stores non-default profiles as "[profile NAME]";
        # a bare "[NAME]" section (e.g. "[default]") is accepted as well.
        config = (
            self._parse_ini_section(config_path, f"profile {profile}")
            or self._parse_ini_section(config_path, profile)
        )

        super().__init__(
            access_key_id=credentials.get("aws_access_key_id", ""),
            secret_access_key=credentials.get("aws_secret_access_key", ""),
            # The caller's region is kept when the config has none.
            region=config.get("region") or region,
            service=service,
        )


# Snapshot of the AWS_* environment variables, taken once at import time.
ENVIRONMENT_CREDENTIALS: StaticCredentials = EnvironmentCredentials()


def merge_credentials(*credentials: StaticCredentials) -> StaticCredentials:
    """Combine several credential sources field by field.

    For every field, the value of the first candidate (in argument order)
    whose value is truthy is used. Fields that no candidate provides keep
    the ``StaticCredentials`` defaults. A candidate lacking a field
    attribute is treated as not providing it.
    """
    field_names = (
        "access_key_id", "secret_access_key",
        "session_token", "region", "service",
    )
    merged = {}
    for name in field_names:
        for candidate in credentials:
            value = getattr(candidate, name, None)
            if value:
                merged[name] = value
                break
    return StaticCredentials(**merged)


def collect_credentials(
    *, url: Optional[URL] = None, **kwargs,
) -> StaticCredentials:
    """Gather credentials from every supported source and merge them.

    Precedence, highest first: explicit keyword arguments (passed to
    ``StaticCredentials``), ``url``, environment variables, AWS config
    files. Fields are combined by ``merge_credentials``.
    """
    explicit = [StaticCredentials(**kwargs)] if kwargs else []
    from_url = [URLCredentials(url)] if url else []
    return merge_credentials(
        *explicit,
        *from_url,
        EnvironmentCredentials(),
        ConfigCredentials(),
    )


class MetadataDocument(TypedDict, total=False):
    """Instance identity document returned by the EC2 metadata service.

    Fetched from ``/latest/dynamic/instance-identity/document``. Only
    ``region`` is declared here. An example response:

        {
            "accountId" : "123123",
            "architecture" : "x86_64",
            "availabilityZone" : "us-east-1a",
            "billingProducts" : null,
            "devpayProductCodes" : null,
            "marketplaceProductCodes" : null,
            "imageId" : "ami-123123",
            "instanceId" : "i-11232323",
            "instanceType" : "t3a.micro",
            "kernelId" : null,
            "pendingTime" : "2023-06-13T18:18:58Z",
            "privateIp" : "172.33.33.33",
            "ramdiskId" : null,
            "region" : "us-east-1",
            "version" : "2017-09-30"
        }
    """

    region: str


class MetadataSecurityCredentials(TypedDict, total=False):
    """JSON returned by ``/latest/meta-data/iam/security-credentials/<role>``."""

    Code: str
    Type: str
    AccessKeyId: str
    SecretAccessKey: str
    Token: str
    Expiration: str


class MetadataCredentials(AbstractCredentials, anyio.AsyncContextManagerMixin):
    """Instance-profile credentials from the EC2 instance metadata service.

    Use this class as an async context manager. Entering starts a
    background task that fetches the credentials, then fetches them again
    after half of their remaining lifetime. The ``async with`` statement
    proceeds only after the first successful fetch; failed attempts are
    retried every ``RETRY_INTERVAL`` seconds. Leaving the context cancels
    the background task.

    Each fetch first requests an IMDSv2 session token. If that request
    fails, the token-less IMDSv1 protocol is used.

    :param service: service name forwarded to the request signer.
    """

    METADATA_ADDRESS: str = "169.254.169.254"
    METADATA_PORT: int = 80
    # Lifetime, in seconds, requested for IMDSv2 session tokens.
    METADATA_TOKEN_TTL: int = 21600
    # Delay, in seconds, before retrying after a failed fetch.
    RETRY_INTERVAL: float = 60
    # Lower bound, in seconds, on the delay between successful refreshes.
    # It prevents a busy loop when the reported expiration is (almost)
    # in the past.
    MIN_REFRESH_INTERVAL: float = 5

    def __init__(self, *, service: str = "s3"):
        self.session = AsyncClient(
            base_url=URL(
                scheme="http",
                host=self.METADATA_ADDRESS.rstrip("/"),
                port=self.METADATA_PORT,
            ),
        )
        self.service = service
        self.refresh_lock: anyio.Lock = anyio.Lock()
        self._signer: Optional[AwsRequestSigner] = None

    @contextlib.asynccontextmanager
    async def __asynccontextmanager__(self):
        async with anyio.create_task_group() as taskgroup:
            # start() returns once the refresher reports readiness, i.e.
            # after the first successful credentials fetch.
            await taskgroup.start(
                self._refresher, name="MetadataCredentials-refresher",
            )
            yield self
            taskgroup.cancel_scope.cancel()

    def __bool__(self) -> bool:
        return self._signer is not None

    async def _refresher(
        self, *,
        task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED,
    ) -> None:
        """Refresh the signer forever, rescheduling by credential expiry."""
        started = False
        while True:
            async with self.refresh_lock:
                try:
                    credentials, expires_at = await self._fetch_credentials()
                    self._signer = AwsRequestSigner(**credentials.as_dict())
                except Exception as ex:
                    log.exception("Failed to update credentials", exc_info=ex)
                    sleep_time = self.RETRY_INTERVAL
                else:
                    now = datetime.datetime.now(datetime.timezone.utc)
                    sleep_time = max(
                        math.floor((expires_at - now).total_seconds() / 2),
                        self.MIN_REFRESH_INTERVAL,
                    )
            # task_status.started() may be called only once per task; a
            # second call raises RuntimeError.
            if not started and self._signer is not None:
                started = True
                task_status.started()
            await anyio.sleep(sleep_time)

    async def _fetch_metadata_token(self) -> Optional[str]:
        """Request an IMDSv2 session token.

        Returns ``None`` when the request fails or is refused, in which
        case the caller proceeds with IMDSv1.
        """
        try:
            response = await self.session.put(
                "/latest/api/token",
                headers={
                    "X-aws-ec2-metadata-token-ttl-seconds":
                        str(self.METADATA_TOKEN_TTL),
                },
            )
        except HTTPError:
            log.debug("IMDSv2 token request failed", exc_info=True)
            return None
        if not response.is_success:
            return None
        return response.content.decode().strip() or None

    @staticmethod
    def _parse_expiration(value: str) -> datetime.datetime:
        """Parse an ISO 8601 expiration into an aware UTC-based datetime.

        A trailing ``Z`` is rewritten because ``datetime.fromisoformat``
        accepts it only on Python 3.11+. A value without an offset is
        assumed to be UTC.
        """
        if value.endswith(("Z", "z")):
            value = value[:-1] + "+00:00"
        expires_at = datetime.datetime.fromisoformat(value)
        if expires_at.tzinfo is None:
            expires_at = expires_at.replace(tzinfo=datetime.timezone.utc)
        return expires_at

    async def _fetch_credentials(
        self,
    ) -> Tuple[StaticCredentials, datetime.datetime]:
        """Fetch the region and role credentials plus their expiration.

        :raises httpx.HTTPError: on transport errors or non-2xx responses.
        :raises RuntimeError: when no role is attached to the instance or
            the service reports a non-``Success`` credentials status.
        """
        headers = {}
        token = await self._fetch_metadata_token()
        if token:
            headers["X-aws-ec2-metadata-token"] = token

        response = await self.session.get(
            "/latest/dynamic/instance-identity/document", headers=headers,
        )
        response.raise_for_status()
        document: MetadataDocument = response.json()

        response = await self.session.get(
            "/latest/meta-data/iam/security-credentials/", headers=headers,
        )
        response.raise_for_status()
        # The listing names the instance profile's role; IAM role names
        # contain no whitespace, so the first token is the role.
        roles = response.content.decode().split()
        if not roles:
            raise RuntimeError("No IAM role is attached to this instance")
        iam_role = roles[0]

        response = await self.session.get(
            f"/latest/meta-data/iam/security-credentials/{iam_role}",
            headers=headers,
        )
        response.raise_for_status()
        credentials: MetadataSecurityCredentials = response.json()
        status = credentials.get("Code", "Success")
        if status != "Success":
            raise RuntimeError(
                f"Metadata service reported credentials status {status!r}",
            )

        return (
            StaticCredentials(
                region=document["region"],
                access_key_id=credentials["AccessKeyId"],
                secret_access_key=credentials["SecretAccessKey"],
                session_token=credentials["Token"],
                # Without this the signer always targets the default "s3".
                service=self.service,
            ),
            self._parse_expiration(credentials["Expiration"]),
        )

    @property
    def signer(self) -> AwsRequestSigner:
        """Current signer; raises ``RuntimeError`` before the first fetch."""
        if self._signer is None:
            raise RuntimeError(
                f"{self.__class__.__name__} must be started before using",
            )
        return self._signer


# Public API of this module; the names are kept in alphabetical order.
__all__ = (
    "AbstractCredentials",
    "ConfigCredentials",
    "EnvironmentCredentials",
    "MetadataCredentials",
    "StaticCredentials",
    "URLCredentials",
    "collect_credentials",
    "merge_credentials",
)