1import configparser
2import contextlib
3import datetime
4import logging
5import math
6import os
7from abc import ABC, abstractmethod
8from dataclasses import dataclass
9from functools import cached_property
10from pathlib import Path
11from typing import Any, List, Mapping, Optional, Tuple, TypedDict, Union
12
13import anyio
14from aws_request_signer import AwsRequestSigner
15from httpx import URL, AsyncClient
16
17
18log = logging.getLogger(__name__)
19
20
[docs]
class AbstractCredentials(ABC):
    """Common interface implemented by every credentials provider.

    A provider must be usable in a boolean context and must expose a
    request signer built from the credentials it holds.
    """

    @abstractmethod
    def __bool__(self) -> bool:
        """Return the truthiness of the credentials.

        Concrete classes decide what "populated" means (for example,
        ``StaticCredentials`` requires both key id and secret).
        """

    @property
    @abstractmethod
    def signer(self) -> AwsRequestSigner:
        """Return an ``AwsRequestSigner`` configured from these credentials."""
30
31
[docs]
@dataclass(frozen=True)
class StaticCredentials(AbstractCredentials):
    """Credentials supplied directly as field values.

    Instances are immutable (frozen dataclass). ``signer`` is built on first
    access and then cached on the instance. ``repr()`` masks the secret
    access key and leaves out the session token entirely.
    """

    access_key_id: str = ""
    secret_access_key: str = ""
    session_token: Optional[str] = None
    region: str = ""
    service: str = "s3"

    def __bool__(self) -> bool:
        # Both halves of the key pair must be non-empty to be usable.
        return bool(self.access_key_id) and bool(self.secret_access_key)

    def __repr__(self) -> str:
        # Never print the real secret: show a fixed mask, or None if unset.
        masked_secret = "******" if self.secret_access_key else None
        cls_name = self.__class__.__name__
        return (
            f"{cls_name}(access_key_id={self.access_key_id!r}, "
            f"secret_access_key={masked_secret!r}, "
            f"region={self.region!r}, service={self.service!r})"
        )

    def as_dict(self) -> dict:
        """Return all fields as a dict, used as ``AwsRequestSigner`` kwargs."""
        ordered_keys = (
            "region",
            "access_key_id",
            "secret_access_key",
            "session_token",
            "service",
        )
        return {key: getattr(self, key) for key in ordered_keys}

    @cached_property
    def signer(self) -> AwsRequestSigner:
        """Build the request signer once and reuse it afterwards."""
        signer_kwargs = self.as_dict()
        return AwsRequestSigner(**signer_kwargs)
63
64
[docs]
class URLCredentials(StaticCredentials):
    """Credentials taken from the username/password part of a URL.

    A missing username or password becomes an empty string; the session
    token is left unset.
    """

    def __init__(
        self, url: Union[str, URL], *, region: str = "", service: str = "s3",
    ):
        parsed = URL(url)
        key_id = parsed.username or ""
        secret = parsed.password or ""
        super().__init__(
            access_key_id=key_id,
            secret_access_key=secret,
            region=region,
            service=service,
        )
75
76
[docs]
class EnvironmentCredentials(StaticCredentials):
    """Credentials read from ``AWS_*`` environment variables.

    Variables are read once, when the instance is created. The ``region``
    argument is only a fallback: ``AWS_DEFAULT_REGION``, when set, wins.
    """

    def __init__(self, region: str = "", service: str = "s3"):
        env = os.environ
        super().__init__(
            access_key_id=env.get("AWS_ACCESS_KEY_ID", ""),
            secret_access_key=env.get("AWS_SECRET_ACCESS_KEY", ""),
            session_token=env.get("AWS_SESSION_TOKEN"),
            region=env.get("AWS_DEFAULT_REGION", region),
            service=service,
        )
86
87
[docs]
class ConfigCredentials(StaticCredentials):
    """Credentials read from the shared AWS ``credentials`` and ``config`` files.

    The access key, secret key and (if present) ``aws_session_token`` come
    from the profile's section of the credentials file. The region comes from
    the profile's section of the config file; when the config file has no
    region, the ``region`` argument is used instead.

    Either file may be missing or unreadable; it is then treated as empty.
    The two files are read independently, so a user with only a
    credentials file still gets credentials.

    :param credentials_path: path of the credentials file. When ``None``,
        ``AWS_SHARED_CREDENTIALS_FILE`` is consulted, then
        ``~/.aws/credentials``.
    :param config_path: path of the config file. When explicitly ``None``,
        ``AWS_SHARED_CONFIG_FILE`` is consulted, then ``~/.aws/config``.
    :param region: fallback region when the config file does not set one.
    :param service: AWS service name used for signing.
    :param profile: profile name; ``"auto"`` means ``$AWS_PROFILE`` or
        ``"default"``.
    """

    DEFAULT_CREDENTIALS_PATH = Path.home() / ".aws" / "credentials"
    DEFAULT_CONFIG_PATH = Path.home() / ".aws" / "config"

    @staticmethod
    def _parse_ini_section(
        path: Path, section: str, *fallbacks: str,
    ) -> Mapping[str, str]:
        """Return the first of ``section``, ``*fallbacks`` present in ``path``.

        Returns an empty mapping when the file cannot be read or none of the
        requested sections exist. Interpolation is disabled so values
        containing ``%`` are returned verbatim instead of raising.
        """
        conf = configparser.ConfigParser(interpolation=None)
        # read() silently skips files it cannot open (OSError) and returns
        # the list of files that were parsed.
        if not conf.read(path):
            return {}

        for name in (section, *fallbacks):
            if name in conf:
                return conf[name]
        return {}

    def __init__(
        self,
        credentials_path: Union[str, Path, None] = None,
        config_path: Union[str, Path, None] = DEFAULT_CONFIG_PATH, *,
        region: str = "", service: str = "s3", profile: str = "auto",
    ):
        if credentials_path is None:
            credentials_path = os.getenv(
                "AWS_SHARED_CREDENTIALS_FILE",
                self.DEFAULT_CREDENTIALS_PATH,
            )
        credentials_path = Path(credentials_path).expanduser()

        if config_path is None:
            config_path = os.getenv(
                "AWS_SHARED_CONFIG_FILE",
                self.DEFAULT_CONFIG_PATH,
            )
        config_path = Path(config_path).expanduser()

        if profile == "auto":
            profile = os.getenv("AWS_PROFILE", "default")

        # The credentials file names sections after the bare profile name.
        credentials = self._parse_ini_section(credentials_path, profile)
        # The config file uses "[profile NAME]" for profiles; plain "[NAME]"
        # (the usual form for "default") is accepted as a fallback.
        config = self._parse_ini_section(
            config_path, f"profile {profile}", profile,
        )

        super().__init__(
            access_key_id=credentials.get("aws_access_key_id", ""),
            secret_access_key=credentials.get("aws_secret_access_key", ""),
            session_token=credentials.get("aws_session_token") or None,
            # Keep the caller's region when the config file has none.
            region=config.get("region") or region,
            service=service,
        )
154
155
# Snapshot of the AWS_* environment variables taken once, at import time;
# later changes to os.environ are not reflected in this instance.
ENVIRONMENT_CREDENTIALS: EnvironmentCredentials = EnvironmentCredentials()
157
158
[docs]
def merge_credentials(*credentials: StaticCredentials) -> StaticCredentials:
    """Merge several credential sources; the first non-empty value wins.

    Sources are consulted in the given order (highest priority first). For
    each field, the first truthy value found is kept.

    ``session_token`` is special: once some source has supplied
    ``access_key_id``, session tokens from later sources are ignored.
    A temporary session token is only valid together with the key it was
    issued for. Without this rule, an explicitly passed long-term key could
    be paired with an unrelated ``AWS_SESSION_TOKEN`` from the environment,
    producing requests AWS rejects.

    :param credentials: candidate credentials, highest priority first.
    :return: a new ``StaticCredentials`` holding the merged fields.
    """
    fields = (
        "access_key_id", "secret_access_key",
        "session_token", "region", "service",
    )
    result: dict = {}
    # The candidate that supplied access_key_id, once one has.
    key_source: Optional[StaticCredentials] = None

    for candidate in credentials:
        for field in fields:
            if field in result:
                continue
            value = getattr(candidate, field, None)
            if not value:
                continue
            if (
                field == "session_token"
                and key_source is not None
                and key_source is not candidate
            ):
                # Token belongs to a different key than the one selected.
                continue
            result[field] = value
            if field == "access_key_id":
                key_source = candidate

    return StaticCredentials(**result)
176
177
[docs]
def collect_credentials(
    *, url: Optional[URL] = None, **kwargs,
) -> StaticCredentials:
    """Gather credentials from every known source and merge them.

    Priority, highest first:

    1. explicit keyword arguments (passed to ``StaticCredentials``),
    2. username/password embedded in ``url``,
    3. ``AWS_*`` environment variables,
    4. the shared AWS credentials/config files.

    The sources are combined with ``merge_credentials``.
    """
    explicit = [StaticCredentials(**kwargs)] if kwargs else []
    from_url = [URLCredentials(url)] if url else []
    return merge_credentials(
        *explicit,
        *from_url,
        EnvironmentCredentials(),
        ConfigCredentials(),
    )
189
190
class MetadataDocument(TypedDict, total=False):
    """Typed view of an instance metadata document.

    Only ``region`` is declared, and it is optional (``total=False``).
    Any other keys in the response are simply not typed. An example
    response looks like::

        {
            "accountId" : "123123",
            "architecture" : "x86_64",
            "availabilityZone" : "us-east-1a",
            "billingProducts" : null,
            "devpayProductCodes" : null,
            "marketplaceProductCodes" : null,
            "imageId" : "ami-123123",
            "instanceId" : "i-11232323",
            "instanceType" : "t3a.micro",
            "kernelId" : null,
            "pendingTime" : "2023-06-13T18:18:58Z",
            "privateIp" : "172.33.33.33",
            "ramdiskId" : null,
            "region" : "us-east-1",
            "version" : "2017-09-30"
        }
    """

    region: str
213 region: str
214
215
class MetadataSecurityCredentials(TypedDict, total=False):
    """Typed view of a security-credentials response from instance metadata.

    All keys are optional (``total=False``). The key names are kept exactly
    as they appear in the response; the values are typed as strings.
    """

    Code: str
    Type: str
    AccessKeyId: str
    SecretAccessKey: str
    Token: str
    # Kept as the raw string; presumably an ISO-8601 timestamp — confirm.
    Expiration: str
223
224
296
297
# Names exported by ``from <module> import *``; kept alphabetical
# (classes first, then functions).
__all__ = (
    "AbstractCredentials", "ConfigCredentials", "EnvironmentCredentials",
    "MetadataCredentials", "StaticCredentials", "URLCredentials",
    "collect_credentials", "merge_credentials",
)