Source code for handtruck.client

  1import hashlib
  2import io
  3import logging
  4import os
  5import typing as t
  6from collections import deque
  7from contextlib import suppress
  8from dataclasses import dataclass
  9from functools import partial
 10from http import HTTPStatus
 11from itertools import chain
 12from mimetypes import guess_type
 13from mmap import PAGESIZE
 14from pathlib import Path
 15from tempfile import NamedTemporaryFile
 16
 17import anyio
 18import anyio.streams.memory
 19from anyiomisc import asyncbackoff
 20from aws_request_signer import UNSIGNED_PAYLOAD
 21from httpx import URL, HTTPError, AsyncClient, Response, QueryParams
 22
 23from ._async import generate_in_thread, run_in_thread, parallel_file_writer, ChunkSendStream
 24from ._xml import (
 25    AwsObjectMeta, create_complete_upload_request,
 26    parse_create_multipart_upload_id, parse_list_objects,
 27)
 28from .credentials import (
 29    AbstractCredentials, collect_credentials,
 30)
 31
 32log = logging.getLogger(__name__)
 33
 34CHUNK_SIZE = 2 ** 16
 35
 36DONE = object()
 37EMPTY_STR_HASH = hashlib.sha256(b"").hexdigest()
 38# 5MB
 39
 40PART_SIZE = 5 * 1024 * 1024
 41HeadersType = t.Union[t.Dict]
 42
 43DataType = t.Optional[t.Mapping[str, t.Any]]
 44RequestContent = t.Optional[t.Union[str, bytes, t.Iterable[bytes], t.AsyncIterable[bytes]]]
 45PrimitiveData = t.Optional[t.Union[str, int, float, bool]]
 46QueryParamTypes = t.Union[
 47    QueryParams,
 48    t.Mapping[str, t.Union[PrimitiveData, t.Sequence[PrimitiveData]]],
 49    t.List[t.Tuple[str, PrimitiveData]],
 50    t.Tuple[t.Tuple[str, PrimitiveData], ...],
 51    str,
 52    bytes,
 53]
 54
 55
 56@dataclass
 57class HEADERS:
 58    CONTENT_LENGTH = 'Content-Length'
 59    CONTENT_TYPE = 'Content-Type'
 60
 61
 62class AwsError(HTTPError):
 63    pass
 64
 65
 66class AwsUploadError(AwsError):
 67    pass
 68
 69
 70class AwsDownloadError(AwsError):
 71    pass
 72
 73
 74@run_in_thread
 75def concat_files(
 76    target_file: Path, files: t.List[t.IO[bytes]], buffer_size: int,
 77) -> None:
 78    with target_file.open("ab") as fp:
 79        for file in files:
 80            file.seek(0)
 81            while True:
 82                chunk = file.read(buffer_size)
 83                if not chunk:
 84                    break
 85                fp.write(chunk)
 86            file.close()
 87
 88
 89@run_in_thread
 90def write_from_start(
 91    file: io.BytesIO, chunk: bytes, range_start: int, pos: int,
 92) -> None:
 93    file.seek(pos - range_start)
 94    file.write(chunk)
 95
 96
 97@generate_in_thread
 98def gen_without_hash(
 99    stream: t.Iterable[bytes],
100) -> t.Generator[t.Tuple[None, bytes], None, None]:
101    for data in stream:
102        yield (None, data)
103
104
@generate_in_thread
def gen_with_hash(
    stream: t.Iterable[bytes],
) -> t.Generator[t.Tuple[str, bytes], None, None]:
    """Yield `(hex sha256 of chunk, chunk)` for every chunk of `stream`."""
    yield from (
        (hashlib.sha256(chunk).hexdigest(), chunk) for chunk in stream
    )
111
112
def file_sender(
    file_name: t.Union[str, Path], chunk_size: int = CHUNK_SIZE,
) -> t.Iterable[bytes]:
    """
    Lazily read the file `file_name` in binary mode, yielding blocks of
    at most `chunk_size` bytes until the file is exhausted.
    """
    with open(file_name, "rb") as handle:
        # An empty read marks the end of the file and stops the iteration.
        yield from iter(partial(handle.read, chunk_size), b"")


# `file_sender` wrapped with the project's `generate_in_thread` helper.
async_file_sender = generate_in_thread(file_sender)

# Items sent to upload workers are `(part_no, part_hash, part)` tuples.
MultiPart = tuple

127
class S3Client:
    """
    Asynchronous client for an S3-compatible HTTP endpoint.

    Requests are sent through the supplied httpx `AsyncClient` and are
    signed with the signer exposed by the configured credentials.
    """

    def __init__(
        self, client: AsyncClient, url: t.Union[URL, str],
        secret_access_key: t.Optional[str] = None,
        access_key_id: t.Optional[str] = None,
        session_token: t.Optional[str] = None,
        region: str = "",
        credentials: t.Optional[AbstractCredentials] = None,
    ):
        """
        client: httpx client used to perform the requests
        url: base URL that request paths are joined to
        secret_access_key, access_key_id, session_token, region: passed
            to `collect_credentials` when `credentials` is not given
        credentials: ready credentials object; takes precedence over the
            individual values above

        Raises ValueError when the resulting credentials are falsy.
        """
        base_url = URL(url)

        resolved = credentials
        if resolved is None:
            resolved = collect_credentials(
                url=base_url,
                access_key_id=access_key_id,
                region=region,
                secret_access_key=secret_access_key,
                session_token=session_token,
            )

        if not resolved:
            raise ValueError(
                f"Credentials {resolved!r} is incomplete",
            )

        self._client = client
        self._credentials = resolved
        self._url = base_url

    @property
    def url(self) -> URL:
        """Base URL this client was created with."""
        return self._url

[docs] 160 async def request( 161 self, method: str, path: str, 162 headers: t.Optional[HeadersType] = None, 163 params: t.Optional[QueryParams] = None, 164 content: t.Optional[RequestContent] = None, 165 content_sha256: t.Optional[str] = None, 166 **kwargs, 167 ) -> Response: 168 headers = self._prepare_headers(headers) 169 170 if content is not None and content_sha256 is None: 171 content_sha256 = UNSIGNED_PAYLOAD 172 173 url = (self._url.join(path)) 174 if params: 175 url = url.copy_merge_params(params) 176 177 headers = self._make_headers(headers) 178 headers.update( 179 self._credentials.signer.sign_with_headers( 180 method, str(url), headers=headers, content_hash=content_sha256, 181 ), 182 ) 183 return await self._client.request( 184 method, url, headers=headers, content=content, **kwargs, 185 )
186
[docs] 187 async def get(self, object_name: str, **kwargs) -> Response: 188 return await self.request("GET", object_name, **kwargs)
189
[docs] 190 async def head( 191 self, object_name: str, 192 content_sha256=EMPTY_STR_HASH, 193 **kwargs, 194 ) -> Response: 195 return await self.request( 196 "HEAD", object_name, content_sha256=content_sha256, **kwargs, 197 )
198
[docs] 199 async def delete( 200 self, object_name: str, 201 content_sha256=EMPTY_STR_HASH, 202 **kwargs, 203 ) -> Response: 204 return await self.request( 205 "DELETE", object_name, content_sha256=content_sha256, **kwargs, 206 )
207 208 @staticmethod 209 def _make_headers(headers: t.Optional[HeadersType]) -> dict: 210 headers = dict(headers or {}) 211 return headers 212 213 def _prepare_headers( 214 self, headers: t.Optional[HeadersType], 215 file_path: str = "", 216 ) -> dict: 217 headers = self._make_headers(headers) 218 219 if HEADERS.CONTENT_TYPE not in headers: 220 content_type = guess_type(file_path)[0] 221 if content_type is None: 222 content_type = "application/octet-stream" 223 224 headers[HEADERS.CONTENT_TYPE] = content_type 225 226 return headers 227
[docs] 228 async def put( 229 self, object_name: str, 230 content: RequestContent, 231 **kwargs, 232 ) -> Response: 233 return await self.request("PUT", object_name, content=content, **kwargs)
234
[docs] 235 async def post( 236 self, object_name: str, 237 content: RequestContent = None, 238 **kwargs, 239 ) -> Response: 240 return await self.request("POST", object_name, content=content, **kwargs)
241
[docs] 242 async def put_file( 243 self, object_name: t.Union[str, Path], 244 file_path: t.Union[str, Path], 245 *, headers: t.Optional[HeadersType] = None, 246 chunk_size: int = CHUNK_SIZE, content_sha256: t.Optional[str] = None, 247 ) -> Response: 248 249 headers = self._prepare_headers(headers, str(file_path)) 250 return await self.put( 251 str(object_name), 252 headers=headers, 253 content=async_file_sender( 254 file_path, 255 chunk_size=chunk_size, 256 ), 257 content_sha256=content_sha256, 258 )
259 260 @asyncbackoff( 261 None, None, 0, 262 max_tries=3, exceptions=(HTTPError,), 263 ) 264 async def _create_multipart_upload( 265 self, 266 object_name: str, 267 headers: t.Optional[HeadersType] = None, 268 ) -> str: 269 resp = await self.post( 270 object_name, 271 headers=headers, 272 params={"uploads": 1}, 273 content_sha256=EMPTY_STR_HASH, 274 ) 275 payload = resp.read() 276 if resp.status_code != HTTPStatus.OK: 277 raise AwsUploadError( 278 f"Wrong status code {resp.status_code} from s3 with message " 279 f"{payload.decode()}.", 280 ) 281 return parse_create_multipart_upload_id(payload) 282 283 @asyncbackoff( 284 None, None, 0, 285 max_tries=3, exceptions=(AwsUploadError, HTTPError), 286 ) 287 async def _complete_multipart_upload( 288 self, 289 upload_id: str, 290 object_name: str, 291 parts: t.List[t.Tuple[int, str]], 292 ) -> None: 293 complete_upload_request = create_complete_upload_request(parts) 294 resp = await self.post( 295 object_name, 296 headers={"Content-Type": "text/xml"}, 297 params={"uploadId": upload_id}, 298 content=complete_upload_request, 299 content_sha256=hashlib.sha256(complete_upload_request).hexdigest(), 300 ) 301 if resp.status_code != HTTPStatus.OK: 302 payload = resp.content 303 raise AwsUploadError( 304 f"Wrong status code {resp.status_code} from s3 with message " 305 f"{payload!r}.", 306 ) 307 308 async def _put_part( 309 self, 310 upload_id: str, 311 object_name: str, 312 part_no: int, 313 content: RequestContent, 314 content_sha256: str, 315 **kwargs, 316 ) -> str: 317 resp = await self.put( 318 object_name, 319 params={"partNumber": part_no, "uploadId": upload_id}, 320 content=content, 321 content_sha256=content_sha256, 322 **kwargs, 323 ) 324 payload = resp.content 325 if resp.status_code != HTTPStatus.OK: 326 raise AwsUploadError( 327 f"Wrong status code {resp.status_code} from s3 with message " 328 f"{payload!r}.", 329 ) 330 return resp.headers["Etag"].strip('"') 331 332 async def _part_uploader( 333 self, 334 upload_id: 
str, 335 object_name: str, 336 parts_stream: anyio.streams.memory.MemoryObjectReceiveStream[MultiPart], 337 results_queue: deque, 338 part_upload_tries: int, 339 **kwargs, 340 ) -> None: 341 backoff = asyncbackoff( 342 None, None, 343 max_tries=part_upload_tries, 344 exceptions=(HTTPError,), 345 ) 346 async for part_no, part_hash, part in parts_stream: 347 etag = await backoff(self._put_part)( 348 upload_id=upload_id, 349 object_name=object_name, 350 part_no=part_no, 351 content=part, 352 content_sha256=part_hash, 353 **kwargs, 354 ) 355 log.debug( 356 "Etag for part %d of %s is %s", part_no, upload_id, etag, 357 ) 358 results_queue.append((part_no, etag)) 359
[docs] 360 async def put_file_multipart( 361 self, 362 object_name: t.Union[str, Path], 363 file_path: t.Union[str, Path], 364 *, 365 headers: t.Optional[HeadersType] = None, 366 part_size: int = PART_SIZE, 367 workers_count: int = 1, 368 max_size: t.Optional[int] = None, 369 part_upload_tries: int = 3, 370 calculate_content_sha256: bool = True, 371 **kwargs, 372 ) -> None: 373 """ 374 Upload data from a file with multipart upload 375 376 object_name: key in s3 377 file_path: path to a file for upload 378 headers: additional headers, such as Content-Type 379 part_size: size of a chunk to send (recommended: >5Mb) 380 workers_count: count of coroutines for asyncronous parts uploading 381 max_size: maximum size of a queue with data to send (should be 382 at least `workers_count`) 383 part_upload_tries: how many times trying to put part to s3 before fail 384 calculate_content_sha256: whether to calculate sha256 hash of a part 385 for integrity purposes 386 """ 387 log.debug( 388 "Going to multipart upload %s to %s with part size %d", 389 file_path, object_name, part_size, 390 ) 391 await self.put_multipart( 392 object_name, 393 file_sender( 394 file_path, 395 chunk_size=part_size, 396 ), 397 headers=headers, 398 workers_count=workers_count, 399 max_size=max_size, 400 part_upload_tries=part_upload_tries, 401 calculate_content_sha256=calculate_content_sha256, 402 **kwargs, 403 )
404 405 async def _parts_generator( 406 self, gen: t.AsyncIterable[tuple], workers_count: int, parts_stream: anyio.streams.memory.MemoryObjectSendStream[MultiPart], 407 ) -> int: 408 part_no = 1 409 async with parts_stream: 410 async for part_hash, part in gen: 411 log.debug( 412 "Reading part %d (%d bytes)", part_no, len(part), 413 ) 414 await parts_stream.send((part_no, part_hash, part)) 415 part_no += 1 416 417 return part_no 418
[docs] 419 async def put_multipart( 420 self, 421 object_name: t.Union[str, Path], 422 content: t.Iterable[bytes], 423 *, 424 headers: t.Optional[HeadersType] = None, 425 workers_count: int = 1, 426 max_size: t.Optional[int] = None, 427 part_upload_tries: int = 3, 428 calculate_content_sha256: bool = True, 429 **kwargs, 430 ) -> None: 431 """ 432 Send data from iterable with multipart upload 433 434 object_name: key in s3 435 data: any iterable that returns chunks of bytes 436 headers: additional headers, such as Content-Type 437 workers_count: count of coroutines for asyncronous parts uploading 438 max_size: maximum size of a queue with data to send (should be 439 at least `workers_count`) 440 part_upload_tries: how many times trying to put part to s3 before fail 441 calculate_content_sha256: whether to calculate sha256 hash of a part 442 for integrity purposes 443 """ 444 if workers_count < 1: 445 raise ValueError( 446 f"Workers count should be > 0. Got {workers_count}", 447 ) 448 max_size = max_size or workers_count 449 450 upload_id = await self._create_multipart_upload( 451 str(object_name), 452 headers=headers, 453 ) 454 log.debug("Got upload id %s for %s", upload_id, object_name) 455 456 results_queue: deque = deque() 457 try: 458 async with anyio.create_task_group() as tg: 459 send_stream, receive_stream = anyio.create_memory_object_stream() 460 for wid in range(workers_count): 461 tg.start_soon(partial( 462 self._part_uploader, 463 upload_id, 464 str(object_name), 465 receive_stream.clone(), 466 results_queue, 467 part_upload_tries, 468 **kwargs, 469 ), name=f"put-worker-{upload_id}@{wid}") 470 # Get rid of our copy 471 receive_stream.close() 472 del receive_stream 473 474 if calculate_content_sha256: 475 gen = gen_with_hash(content) 476 else: 477 gen = gen_without_hash(content) 478 479 part_no = await self._parts_generator(gen, workers_count, send_stream) 480 except* Exception as excgroup: 481 for exc in excgroup.exceptions: 482 raise exc from None 483 
484 log.debug( 485 "All parts (#%d) of %s are uploaded to %s", 486 part_no - 1, upload_id, object_name, 487 ) 488 489 # Parts should be in ascending order 490 parts = sorted(results_queue, key=lambda x: x[0]) 491 await self._complete_multipart_upload( 492 upload_id, str(object_name), parts, 493 )
494 495 async def _download_range( 496 self, 497 object_name: str, 498 writer: ChunkSendStream, 499 *, 500 etag: str, 501 pos: int, 502 range_start: int, 503 req_range_start: int, 504 req_range_end: int, 505 buffer_size: int, 506 headers: t.Optional[HeadersType] = None, 507 **kwargs, 508 ) -> None: 509 """ 510 Downloading range [req_range_start:req_range_end] to `file` 511 """ 512 log.debug( 513 "Downloading %s from %d to %d", 514 object_name, 515 req_range_start, 516 req_range_end, 517 ) 518 if not headers: 519 headers = {} 520 headers = headers.copy() 521 headers["Range"] = f"bytes={req_range_start}-{req_range_end}" 522 headers["If-Match"] = etag 523 524 resp = await self.get(object_name, headers=headers, **kwargs) 525 if resp.status_code not in (HTTPStatus.PARTIAL_CONTENT, HTTPStatus.OK): 526 raise AwsDownloadError( 527 f"Got wrong status code {resp.status_code} on range download " 528 f"of {object_name}", 529 ) 530 assert 'Content-Range' in resp.headers 531 assert resp.headers['Content-Range'].startswith(f"bytes {req_range_start}-{req_range_end}/") 532 # FIXME: Handle OK 533 # FIXME: Handle Content-Range being different from requested 534 pos = req_range_start 535 async for chunk in resp.aiter_bytes(buffer_size): 536 if not chunk: 537 break 538 await writer.send((pos, chunk)) 539 pos += len(chunk) 540 541 async def _download_worker( 542 self, 543 object_name: str, 544 writer: ChunkSendStream, 545 *, 546 etag: str, 547 range_step: int, 548 range_start: int, 549 range_end: int, 550 buffer_size: int, 551 range_get_tries: int = 3, 552 headers: t.Optional[HeadersType] = None, 553 **kwargs, 554 ) -> None: 555 """ 556 Downloads data in range `[range_start, range_end)` 557 with step `range_step` to file `file_path`. 558 Uses `etag` to make sure that file wasn't changed in the process. 
559 """ 560 log.debug( 561 "Starting download worker for range [%d:%d]", 562 range_start, 563 range_end, 564 ) 565 async with writer: 566 backoff = asyncbackoff( 567 None, None, 568 max_tries=range_get_tries, 569 exceptions=(HTTPError,), 570 ) 571 req_range_end = range_start 572 for req_range_start in range(range_start, range_end, range_step): 573 req_range_end += range_step 574 if req_range_end > range_end: 575 req_range_end = range_end 576 await backoff(self._download_range)( 577 object_name, 578 writer, 579 etag=etag, 580 pos=(req_range_start - range_start), 581 range_start=range_start, 582 req_range_start=req_range_start, 583 req_range_end=req_range_end - 1, 584 buffer_size=buffer_size, 585 headers=headers, 586 **kwargs, 587 ) 588
[docs] 589 async def get_file_parallel( 590 self, 591 object_name: t.Union[str, Path], 592 file_path: t.Union[str, Path], 593 *, 594 headers: t.Optional[HeadersType] = None, 595 range_step: int = PART_SIZE, 596 workers_count: int = 1, 597 range_get_tries: int = 3, 598 buffer_size: int = PAGESIZE * 32, 599 **kwargs, 600 ) -> None: 601 """ 602 Download object in parallel with requests with Range. 603 If file will change while download is in progress - 604 error will be raised. 605 606 object_name: s3 key to download 607 file_path: target file path 608 headers: additional headers 609 range_step: how much data will be downloaded in single HTTP request 610 workers_count: count of parallel workers 611 range_get_tries: count of tries to download each range 612 buffer_size: size of a buffer for on the fly data 613 """ 614 file_path = Path(file_path) 615 resp = await self.head(str(object_name), headers=headers) 616 if resp.status_code != HTTPStatus.OK: 617 raise AwsDownloadError( 618 f"Got response for HEAD request for {object_name} " 619 f"of a wrong status {resp.status_code}", 620 ) 621 etag = resp.headers["Etag"] 622 file_size = int(resp.headers["Content-Length"]) 623 log.debug( 624 "Object's %s etag is %s and size is %d", 625 object_name, 626 etag, 627 file_size, 628 ) 629 630 worker_range_size = file_size // workers_count 631 range_end = 0 632 try: 633 try: 634 async with ( 635 await anyio.open_file(file_path, "w+b") as fp, 636 parallel_file_writer(fp) as pfw, 637 anyio.create_task_group() as tg, 638 ): 639 for range_start in range(0, file_size, worker_range_size): 640 range_end += worker_range_size 641 if range_end > file_size: 642 range_end = file_size 643 tg.start_soon(partial( 644 self._download_worker, 645 str(object_name), 646 await pfw.get_block(range_start, range_end), 647 buffer_size=buffer_size, 648 etag=etag, 649 headers=headers, 650 range_end=range_end, 651 range_get_tries=range_get_tries, 652 range_start=range_start, 653 range_step=range_step, 654 
**kwargs, 655 ), name=f"get-worker@{range_start}") 656 except* Exception as excgroup: 657 # Unwrap and raise just one 658 for exc in excgroup.exceptions: 659 raise exc from None 660 661 except Exception: 662 log.exception( 663 "Error on file download. Removing possibly incomplete file %s", 664 file_path, 665 ) 666 with suppress(FileNotFoundError): 667 os.unlink(file_path) 668 raise
669
[docs] 670 async def list_objects_v2( 671 self, 672 object_name: t.Union[str, Path] = "/", 673 *, 674 bucket: t.Optional[str] = None, 675 prefix: t.Optional[t.Union[str, Path]] = None, 676 delimiter: t.Optional[str] = None, 677 max_keys: t.Optional[int] = None, 678 start_after: t.Optional[str] = None, 679 ) -> t.AsyncIterator[t.List[AwsObjectMeta]]: 680 """ 681 List objects in bucket. 682 683 Returns an iterator over lists of metadata objects, each corresponding 684 to an individual response result (typically limited to 1000 keys). 685 686 object_name: 687 path to listing endpoint, defaults to '/'; a `bucket` value is 688 prepended to this value if provided. 689 prefix: 690 limits the response to keys that begin with the specified 691 prefix 692 delimiter: a delimiter is a character you use to group keys 693 max_keys: maximum number of keys returned in the response 694 start_after: keys to start listing after 695 """ 696 697 params = { 698 "list-type": "2", 699 } 700 701 if prefix: 702 params["prefix"] = str(prefix) 703 704 if delimiter: 705 params["delimiter"] = delimiter 706 707 if max_keys: 708 params["max-keys"] = str(max_keys) 709 710 if start_after: 711 params["start-after"] = start_after 712 713 if bucket is not None: 714 object_name = f"/{bucket}" 715 716 while True: 717 resp = await self.get(str(object_name), params=params) 718 if resp.status_code != HTTPStatus.OK: 719 raise AwsDownloadError( 720 f"Got response with wrong status for GET request for " 721 f"{object_name} with prefix '{prefix}'", 722 ) 723 payload = resp.content 724 metadata, continuation_token = parse_list_objects(payload) 725 if not metadata: 726 break 727 yield metadata 728 if not continuation_token: 729 break 730 params["continuation-token"] = continuation_token