1import hashlib
2import io
3import logging
4import os
5import typing as t
6from collections import deque
7from contextlib import suppress
8from dataclasses import dataclass
9from functools import partial
10from http import HTTPStatus
11from itertools import chain
12from mimetypes import guess_type
13from mmap import PAGESIZE
14from pathlib import Path
15from tempfile import NamedTemporaryFile
16
17import anyio
18import anyio.streams.memory
19from anyiomisc import asyncbackoff
20from aws_request_signer import UNSIGNED_PAYLOAD
21from httpx import URL, HTTPError, AsyncClient, Response, QueryParams
22
23from ._async import generate_in_thread, run_in_thread, parallel_file_writer, ChunkSendStream
24from ._xml import (
25 AwsObjectMeta, create_complete_upload_request,
26 parse_create_multipart_upload_id, parse_list_objects,
27)
28from .credentials import (
29 AbstractCredentials, collect_credentials,
30)
31
# Module-level logger named after this module.
log = logging.getLogger(__name__)

# Default read size in bytes (64 KiB) for `file_sender` and
# `S3Client.put_file`.
CHUNK_SIZE = 2 ** 16

# Unique sentinel object; it is not referenced by the code visible here.
DONE = object()
# SHA-256 hex digest of an empty body. Passed as `content_sha256` by the
# requests that carry no payload (HEAD, DELETE, create multipart upload).
EMPTY_STR_HASH = hashlib.sha256(b"").hexdigest()

# 5MB: default part size for multipart uploads and default range step for
# parallel downloads.
PART_SIZE = 5 * 1024 * 1024
# Type of the `headers` arguments accepted by the client methods.
HeadersType = t.Union[t.Dict]

# Aliases describing request payloads and query parameters.
DataType = t.Optional[t.Mapping[str, t.Any]]
RequestContent = t.Optional[t.Union[str, bytes, t.Iterable[bytes], t.AsyncIterable[bytes]]]
PrimitiveData = t.Optional[t.Union[str, int, float, bool]]
QueryParamTypes = t.Union[
    QueryParams,
    t.Mapping[str, t.Union[PrimitiveData, t.Sequence[PrimitiveData]]],
    t.List[t.Tuple[str, PrimitiveData]],
    t.Tuple[t.Tuple[str, PrimitiveData], ...],
    str,
    bytes,
]
54
55
# Namespace for the HTTP header names used by the client.
# NOTE: the attributes below have no type annotations, so `@dataclass`
# registers no fields for them; they are plain class-level string constants.
@dataclass
class HEADERS:
    CONTENT_LENGTH = 'Content-Length'
    CONTENT_TYPE = 'Content-Type'
60
61
class AwsError(HTTPError):
    """Base class for the S3 errors raised by this module."""
64
65
class AwsUploadError(AwsError):
    """Raised when an upload request gets an unexpected status from S3."""
68
69
class AwsDownloadError(AwsError):
    """Raised when a download or listing request gets an unexpected status."""
72
73
@run_in_thread
def concat_files(
    target_file: Path, files: t.List[t.IO[bytes]], buffer_size: int,
) -> None:
    """
    Append the contents of `files`, in order, to `target_file`.

    target_file: file opened in append-binary mode and written to
    files: source file objects; each one is rewound to its start, copied
        and then closed
    buffer_size: number of bytes read from a source per iteration

    Every source file is closed when the function returns, including when
    copying fails part-way through, so that no handles are leaked.
    """
    try:
        with target_file.open("ab") as fp:
            for file in files:
                file.seek(0)
                while chunk := file.read(buffer_size):
                    fp.write(chunk)
                # Release each source as soon as it has been copied
                file.close()
    finally:
        # On the success path everything is already closed and this is a
        # no-op. On failure it closes the sources that were not reached;
        # close errors are suppressed so the original exception propagates.
        for file in files:
            with suppress(Exception):
                file.close()
87
88
@run_in_thread
def write_from_start(
    file: io.BytesIO, chunk: bytes, range_start: int, pos: int,
) -> None:
    """
    Write `chunk` into `file` at offset `pos - range_start`.

    file: target buffer
    chunk: data to write
    range_start: absolute position that corresponds to offset 0 of `file`
    pos: absolute position of `chunk`
    """
    offset = pos - range_start
    file.seek(offset)
    file.write(chunk)
95
96
@generate_in_thread
def gen_without_hash(
    stream: t.Iterable[bytes],
) -> t.Generator[t.Tuple[None, bytes], None, None]:
    """Yield `(None, chunk)` for every chunk of `stream` (no hash computed)."""
    for chunk in stream:
        yield None, chunk
103
104
@generate_in_thread
def gen_with_hash(
    stream: t.Iterable[bytes],
) -> t.Generator[t.Tuple[str, bytes], None, None]:
    """Yield `(sha256 hex digest, chunk)` for every chunk of `stream`."""
    for chunk in stream:
        digest = hashlib.sha256(chunk).hexdigest()
        yield digest, chunk
111
112
def file_sender(
    file_name: t.Union[str, Path], chunk_size: int = CHUNK_SIZE,
) -> t.Iterable[bytes]:
    """
    Lazily read a file in binary mode and yield it in chunks.

    file_name: path of the file to read
    chunk_size: maximum number of bytes per yielded chunk

    The file is opened on the first iteration step; iteration stops at
    the first empty read.
    """
    with open(file_name, "rb") as source:
        while block := source.read(chunk_size):
            yield block
122
123
# `file_sender` wrapped with `generate_in_thread`; `S3Client.put_file` passes
# its result as the request content.
async_file_sender = generate_in_thread(file_sender)

# Alias for the items sent through the multipart parts stream, which are
# `(part_no, part_hash, part)` tuples.
MultiPart = tuple
127
[docs]
128class S3Client:
129 def __init__(
130 self, client: AsyncClient, url: t.Union[URL, str],
131 secret_access_key: t.Optional[str] = None,
132 access_key_id: t.Optional[str] = None,
133 session_token: t.Optional[str] = None,
134 region: str = "",
135 credentials: t.Optional[AbstractCredentials] = None,
136 ):
137 url = URL(url)
138 if credentials is None:
139 credentials = collect_credentials(
140 url=url,
141 access_key_id=access_key_id,
142 region=region,
143 secret_access_key=secret_access_key,
144 session_token=session_token,
145 )
146
147 if not credentials:
148 raise ValueError(
149 f"Credentials {credentials!r} is incomplete",
150 )
151
152 self._url = url
153 self._client = client
154 self._credentials = credentials
155
156 @property
157 def url(self) -> URL:
158 return self._url
159
[docs]
160 async def request(
161 self, method: str, path: str,
162 headers: t.Optional[HeadersType] = None,
163 params: t.Optional[QueryParams] = None,
164 content: t.Optional[RequestContent] = None,
165 content_sha256: t.Optional[str] = None,
166 **kwargs,
167 ) -> Response:
168 headers = self._prepare_headers(headers)
169
170 if content is not None and content_sha256 is None:
171 content_sha256 = UNSIGNED_PAYLOAD
172
173 url = (self._url.join(path))
174 if params:
175 url = url.copy_merge_params(params)
176
177 headers = self._make_headers(headers)
178 headers.update(
179 self._credentials.signer.sign_with_headers(
180 method, str(url), headers=headers, content_hash=content_sha256,
181 ),
182 )
183 return await self._client.request(
184 method, url, headers=headers, content=content, **kwargs,
185 )
186
[docs]
187 async def get(self, object_name: str, **kwargs) -> Response:
188 return await self.request("GET", object_name, **kwargs)
189
[docs]
190 async def head(
191 self, object_name: str,
192 content_sha256=EMPTY_STR_HASH,
193 **kwargs,
194 ) -> Response:
195 return await self.request(
196 "HEAD", object_name, content_sha256=content_sha256, **kwargs,
197 )
198
[docs]
199 async def delete(
200 self, object_name: str,
201 content_sha256=EMPTY_STR_HASH,
202 **kwargs,
203 ) -> Response:
204 return await self.request(
205 "DELETE", object_name, content_sha256=content_sha256, **kwargs,
206 )
207
208 @staticmethod
209 def _make_headers(headers: t.Optional[HeadersType]) -> dict:
210 headers = dict(headers or {})
211 return headers
212
213 def _prepare_headers(
214 self, headers: t.Optional[HeadersType],
215 file_path: str = "",
216 ) -> dict:
217 headers = self._make_headers(headers)
218
219 if HEADERS.CONTENT_TYPE not in headers:
220 content_type = guess_type(file_path)[0]
221 if content_type is None:
222 content_type = "application/octet-stream"
223
224 headers[HEADERS.CONTENT_TYPE] = content_type
225
226 return headers
227
[docs]
228 async def put(
229 self, object_name: str,
230 content: RequestContent,
231 **kwargs,
232 ) -> Response:
233 return await self.request("PUT", object_name, content=content, **kwargs)
234
[docs]
235 async def post(
236 self, object_name: str,
237 content: RequestContent = None,
238 **kwargs,
239 ) -> Response:
240 return await self.request("POST", object_name, content=content, **kwargs)
241
[docs]
242 async def put_file(
243 self, object_name: t.Union[str, Path],
244 file_path: t.Union[str, Path],
245 *, headers: t.Optional[HeadersType] = None,
246 chunk_size: int = CHUNK_SIZE, content_sha256: t.Optional[str] = None,
247 ) -> Response:
248
249 headers = self._prepare_headers(headers, str(file_path))
250 return await self.put(
251 str(object_name),
252 headers=headers,
253 content=async_file_sender(
254 file_path,
255 chunk_size=chunk_size,
256 ),
257 content_sha256=content_sha256,
258 )
259
260 @asyncbackoff(
261 None, None, 0,
262 max_tries=3, exceptions=(HTTPError,),
263 )
264 async def _create_multipart_upload(
265 self,
266 object_name: str,
267 headers: t.Optional[HeadersType] = None,
268 ) -> str:
269 resp = await self.post(
270 object_name,
271 headers=headers,
272 params={"uploads": 1},
273 content_sha256=EMPTY_STR_HASH,
274 )
275 payload = resp.read()
276 if resp.status_code != HTTPStatus.OK:
277 raise AwsUploadError(
278 f"Wrong status code {resp.status_code} from s3 with message "
279 f"{payload.decode()}.",
280 )
281 return parse_create_multipart_upload_id(payload)
282
283 @asyncbackoff(
284 None, None, 0,
285 max_tries=3, exceptions=(AwsUploadError, HTTPError),
286 )
287 async def _complete_multipart_upload(
288 self,
289 upload_id: str,
290 object_name: str,
291 parts: t.List[t.Tuple[int, str]],
292 ) -> None:
293 complete_upload_request = create_complete_upload_request(parts)
294 resp = await self.post(
295 object_name,
296 headers={"Content-Type": "text/xml"},
297 params={"uploadId": upload_id},
298 content=complete_upload_request,
299 content_sha256=hashlib.sha256(complete_upload_request).hexdigest(),
300 )
301 if resp.status_code != HTTPStatus.OK:
302 payload = resp.content
303 raise AwsUploadError(
304 f"Wrong status code {resp.status_code} from s3 with message "
305 f"{payload!r}.",
306 )
307
308 async def _put_part(
309 self,
310 upload_id: str,
311 object_name: str,
312 part_no: int,
313 content: RequestContent,
314 content_sha256: str,
315 **kwargs,
316 ) -> str:
317 resp = await self.put(
318 object_name,
319 params={"partNumber": part_no, "uploadId": upload_id},
320 content=content,
321 content_sha256=content_sha256,
322 **kwargs,
323 )
324 payload = resp.content
325 if resp.status_code != HTTPStatus.OK:
326 raise AwsUploadError(
327 f"Wrong status code {resp.status_code} from s3 with message "
328 f"{payload!r}.",
329 )
330 return resp.headers["Etag"].strip('"')
331
332 async def _part_uploader(
333 self,
334 upload_id: str,
335 object_name: str,
336 parts_stream: anyio.streams.memory.MemoryObjectReceiveStream[MultiPart],
337 results_queue: deque,
338 part_upload_tries: int,
339 **kwargs,
340 ) -> None:
341 backoff = asyncbackoff(
342 None, None,
343 max_tries=part_upload_tries,
344 exceptions=(HTTPError,),
345 )
346 async for part_no, part_hash, part in parts_stream:
347 etag = await backoff(self._put_part)(
348 upload_id=upload_id,
349 object_name=object_name,
350 part_no=part_no,
351 content=part,
352 content_sha256=part_hash,
353 **kwargs,
354 )
355 log.debug(
356 "Etag for part %d of %s is %s", part_no, upload_id, etag,
357 )
358 results_queue.append((part_no, etag))
359
[docs]
360 async def put_file_multipart(
361 self,
362 object_name: t.Union[str, Path],
363 file_path: t.Union[str, Path],
364 *,
365 headers: t.Optional[HeadersType] = None,
366 part_size: int = PART_SIZE,
367 workers_count: int = 1,
368 max_size: t.Optional[int] = None,
369 part_upload_tries: int = 3,
370 calculate_content_sha256: bool = True,
371 **kwargs,
372 ) -> None:
373 """
374 Upload data from a file with multipart upload
375
376 object_name: key in s3
377 file_path: path to a file for upload
378 headers: additional headers, such as Content-Type
379 part_size: size of a chunk to send (recommended: >5Mb)
380 workers_count: count of coroutines for asyncronous parts uploading
381 max_size: maximum size of a queue with data to send (should be
382 at least `workers_count`)
383 part_upload_tries: how many times trying to put part to s3 before fail
384 calculate_content_sha256: whether to calculate sha256 hash of a part
385 for integrity purposes
386 """
387 log.debug(
388 "Going to multipart upload %s to %s with part size %d",
389 file_path, object_name, part_size,
390 )
391 await self.put_multipart(
392 object_name,
393 file_sender(
394 file_path,
395 chunk_size=part_size,
396 ),
397 headers=headers,
398 workers_count=workers_count,
399 max_size=max_size,
400 part_upload_tries=part_upload_tries,
401 calculate_content_sha256=calculate_content_sha256,
402 **kwargs,
403 )
404
405 async def _parts_generator(
406 self, gen: t.AsyncIterable[tuple], workers_count: int, parts_stream: anyio.streams.memory.MemoryObjectSendStream[MultiPart],
407 ) -> int:
408 part_no = 1
409 async with parts_stream:
410 async for part_hash, part in gen:
411 log.debug(
412 "Reading part %d (%d bytes)", part_no, len(part),
413 )
414 await parts_stream.send((part_no, part_hash, part))
415 part_no += 1
416
417 return part_no
418
[docs]
419 async def put_multipart(
420 self,
421 object_name: t.Union[str, Path],
422 content: t.Iterable[bytes],
423 *,
424 headers: t.Optional[HeadersType] = None,
425 workers_count: int = 1,
426 max_size: t.Optional[int] = None,
427 part_upload_tries: int = 3,
428 calculate_content_sha256: bool = True,
429 **kwargs,
430 ) -> None:
431 """
432 Send data from iterable with multipart upload
433
434 object_name: key in s3
435 data: any iterable that returns chunks of bytes
436 headers: additional headers, such as Content-Type
437 workers_count: count of coroutines for asyncronous parts uploading
438 max_size: maximum size of a queue with data to send (should be
439 at least `workers_count`)
440 part_upload_tries: how many times trying to put part to s3 before fail
441 calculate_content_sha256: whether to calculate sha256 hash of a part
442 for integrity purposes
443 """
444 if workers_count < 1:
445 raise ValueError(
446 f"Workers count should be > 0. Got {workers_count}",
447 )
448 max_size = max_size or workers_count
449
450 upload_id = await self._create_multipart_upload(
451 str(object_name),
452 headers=headers,
453 )
454 log.debug("Got upload id %s for %s", upload_id, object_name)
455
456 results_queue: deque = deque()
457 try:
458 async with anyio.create_task_group() as tg:
459 send_stream, receive_stream = anyio.create_memory_object_stream()
460 for wid in range(workers_count):
461 tg.start_soon(partial(
462 self._part_uploader,
463 upload_id,
464 str(object_name),
465 receive_stream.clone(),
466 results_queue,
467 part_upload_tries,
468 **kwargs,
469 ), name=f"put-worker-{upload_id}@{wid}")
470 # Get rid of our copy
471 receive_stream.close()
472 del receive_stream
473
474 if calculate_content_sha256:
475 gen = gen_with_hash(content)
476 else:
477 gen = gen_without_hash(content)
478
479 part_no = await self._parts_generator(gen, workers_count, send_stream)
480 except* Exception as excgroup:
481 for exc in excgroup.exceptions:
482 raise exc from None
483
484 log.debug(
485 "All parts (#%d) of %s are uploaded to %s",
486 part_no - 1, upload_id, object_name,
487 )
488
489 # Parts should be in ascending order
490 parts = sorted(results_queue, key=lambda x: x[0])
491 await self._complete_multipart_upload(
492 upload_id, str(object_name), parts,
493 )
494
495 async def _download_range(
496 self,
497 object_name: str,
498 writer: ChunkSendStream,
499 *,
500 etag: str,
501 pos: int,
502 range_start: int,
503 req_range_start: int,
504 req_range_end: int,
505 buffer_size: int,
506 headers: t.Optional[HeadersType] = None,
507 **kwargs,
508 ) -> None:
509 """
510 Downloading range [req_range_start:req_range_end] to `file`
511 """
512 log.debug(
513 "Downloading %s from %d to %d",
514 object_name,
515 req_range_start,
516 req_range_end,
517 )
518 if not headers:
519 headers = {}
520 headers = headers.copy()
521 headers["Range"] = f"bytes={req_range_start}-{req_range_end}"
522 headers["If-Match"] = etag
523
524 resp = await self.get(object_name, headers=headers, **kwargs)
525 if resp.status_code not in (HTTPStatus.PARTIAL_CONTENT, HTTPStatus.OK):
526 raise AwsDownloadError(
527 f"Got wrong status code {resp.status_code} on range download "
528 f"of {object_name}",
529 )
530 assert 'Content-Range' in resp.headers
531 assert resp.headers['Content-Range'].startswith(f"bytes {req_range_start}-{req_range_end}/")
532 # FIXME: Handle OK
533 # FIXME: Handle Content-Range being different from requested
534 pos = req_range_start
535 async for chunk in resp.aiter_bytes(buffer_size):
536 if not chunk:
537 break
538 await writer.send((pos, chunk))
539 pos += len(chunk)
540
541 async def _download_worker(
542 self,
543 object_name: str,
544 writer: ChunkSendStream,
545 *,
546 etag: str,
547 range_step: int,
548 range_start: int,
549 range_end: int,
550 buffer_size: int,
551 range_get_tries: int = 3,
552 headers: t.Optional[HeadersType] = None,
553 **kwargs,
554 ) -> None:
555 """
556 Downloads data in range `[range_start, range_end)`
557 with step `range_step` to file `file_path`.
558 Uses `etag` to make sure that file wasn't changed in the process.
559 """
560 log.debug(
561 "Starting download worker for range [%d:%d]",
562 range_start,
563 range_end,
564 )
565 async with writer:
566 backoff = asyncbackoff(
567 None, None,
568 max_tries=range_get_tries,
569 exceptions=(HTTPError,),
570 )
571 req_range_end = range_start
572 for req_range_start in range(range_start, range_end, range_step):
573 req_range_end += range_step
574 if req_range_end > range_end:
575 req_range_end = range_end
576 await backoff(self._download_range)(
577 object_name,
578 writer,
579 etag=etag,
580 pos=(req_range_start - range_start),
581 range_start=range_start,
582 req_range_start=req_range_start,
583 req_range_end=req_range_end - 1,
584 buffer_size=buffer_size,
585 headers=headers,
586 **kwargs,
587 )
588
[docs]
589 async def get_file_parallel(
590 self,
591 object_name: t.Union[str, Path],
592 file_path: t.Union[str, Path],
593 *,
594 headers: t.Optional[HeadersType] = None,
595 range_step: int = PART_SIZE,
596 workers_count: int = 1,
597 range_get_tries: int = 3,
598 buffer_size: int = PAGESIZE * 32,
599 **kwargs,
600 ) -> None:
601 """
602 Download object in parallel with requests with Range.
603 If file will change while download is in progress -
604 error will be raised.
605
606 object_name: s3 key to download
607 file_path: target file path
608 headers: additional headers
609 range_step: how much data will be downloaded in single HTTP request
610 workers_count: count of parallel workers
611 range_get_tries: count of tries to download each range
612 buffer_size: size of a buffer for on the fly data
613 """
614 file_path = Path(file_path)
615 resp = await self.head(str(object_name), headers=headers)
616 if resp.status_code != HTTPStatus.OK:
617 raise AwsDownloadError(
618 f"Got response for HEAD request for {object_name} "
619 f"of a wrong status {resp.status_code}",
620 )
621 etag = resp.headers["Etag"]
622 file_size = int(resp.headers["Content-Length"])
623 log.debug(
624 "Object's %s etag is %s and size is %d",
625 object_name,
626 etag,
627 file_size,
628 )
629
630 worker_range_size = file_size // workers_count
631 range_end = 0
632 try:
633 try:
634 async with (
635 await anyio.open_file(file_path, "w+b") as fp,
636 parallel_file_writer(fp) as pfw,
637 anyio.create_task_group() as tg,
638 ):
639 for range_start in range(0, file_size, worker_range_size):
640 range_end += worker_range_size
641 if range_end > file_size:
642 range_end = file_size
643 tg.start_soon(partial(
644 self._download_worker,
645 str(object_name),
646 await pfw.get_block(range_start, range_end),
647 buffer_size=buffer_size,
648 etag=etag,
649 headers=headers,
650 range_end=range_end,
651 range_get_tries=range_get_tries,
652 range_start=range_start,
653 range_step=range_step,
654 **kwargs,
655 ), name=f"get-worker@{range_start}")
656 except* Exception as excgroup:
657 # Unwrap and raise just one
658 for exc in excgroup.exceptions:
659 raise exc from None
660
661 except Exception:
662 log.exception(
663 "Error on file download. Removing possibly incomplete file %s",
664 file_path,
665 )
666 with suppress(FileNotFoundError):
667 os.unlink(file_path)
668 raise
669
[docs]
670 async def list_objects_v2(
671 self,
672 object_name: t.Union[str, Path] = "/",
673 *,
674 bucket: t.Optional[str] = None,
675 prefix: t.Optional[t.Union[str, Path]] = None,
676 delimiter: t.Optional[str] = None,
677 max_keys: t.Optional[int] = None,
678 start_after: t.Optional[str] = None,
679 ) -> t.AsyncIterator[t.List[AwsObjectMeta]]:
680 """
681 List objects in bucket.
682
683 Returns an iterator over lists of metadata objects, each corresponding
684 to an individual response result (typically limited to 1000 keys).
685
686 object_name:
687 path to listing endpoint, defaults to '/'; a `bucket` value is
688 prepended to this value if provided.
689 prefix:
690 limits the response to keys that begin with the specified
691 prefix
692 delimiter: a delimiter is a character you use to group keys
693 max_keys: maximum number of keys returned in the response
694 start_after: keys to start listing after
695 """
696
697 params = {
698 "list-type": "2",
699 }
700
701 if prefix:
702 params["prefix"] = str(prefix)
703
704 if delimiter:
705 params["delimiter"] = delimiter
706
707 if max_keys:
708 params["max-keys"] = str(max_keys)
709
710 if start_after:
711 params["start-after"] = start_after
712
713 if bucket is not None:
714 object_name = f"/{bucket}"
715
716 while True:
717 resp = await self.get(str(object_name), params=params)
718 if resp.status_code != HTTPStatus.OK:
719 raise AwsDownloadError(
720 f"Got response with wrong status for GET request for "
721 f"{object_name} with prefix '{prefix}'",
722 )
723 payload = resp.content
724 metadata, continuation_token = parse_list_objects(payload)
725 if not metadata:
726 break
727 yield metadata
728 if not continuation_token:
729 break
730 params["continuation-token"] = continuation_token