111 lines
2.9 KiB
Python
111 lines
2.9 KiB
Python
"""Async utils."""
|
|
import asyncio
|
|
from itertools import zip_longest
|
|
from typing import Any, Coroutine, Iterable, List
|
|
|
|
|
|
def asyncio_module(show_progress: bool = False) -> Any:
    """Return the asyncio-like object to use for gathering coroutines.

    Args:
        show_progress: When true, return ``tqdm.asyncio.tqdm_asyncio``
            (imported lazily, so tqdm is only needed in that case).
            Otherwise return the standard ``asyncio`` module.
    """
    if not show_progress:
        return asyncio

    from tqdm.asyncio import tqdm_asyncio

    return tqdm_asyncio
|
|
|
|
|
|
def run_async_tasks(
    tasks: List[Coroutine],
    show_progress: bool = False,
    progress_bar_desc: str = "Running async tasks",
) -> List[Any]:
    """Run a list of async tasks to completion and return their results.

    Args:
        tasks: Awaitables (typically coroutine objects) to run concurrently.
        show_progress: If True, try to display a tqdm progress bar. This path
            imports ``nest_asyncio`` and ``tqdm.asyncio`` and reuses the
            current event loop. If it fails before any task has started, the
            tasks are run with plain ``asyncio.gather`` instead.
        progress_bar_desc: Label shown on the progress bar.

    Returns:
        The tasks' results as gathered by ``asyncio.gather`` (or tqdm's
        ``gather`` on the progress-bar path).

    Raises:
        Any exception raised by a task is propagated to the caller.
    """
    import inspect
    import logging

    tasks_to_execute: List[Any] = tasks
    if show_progress:
        try:
            import nest_asyncio
            from tqdm.asyncio import tqdm

            # jupyter notebooks already have an event loop running
            # we need to reuse it instead of creating a new one
            nest_asyncio.apply()
            loop = asyncio.get_event_loop()

            async def _tqdm_gather() -> List[Any]:
                return await tqdm.gather(*tasks_to_execute, desc=progress_bar_desc)

            tqdm_outputs: List[Any] = loop.run_until_complete(_tqdm_gather())
            return tqdm_outputs
        except Exception as exc:
            # A coroutine that has already started cannot be awaited again.
            # If any task has run, the fallback below could only fail with
            # "cannot reuse already awaited coroutine" and would hide the
            # real error (e.g. an exception raised by the task itself).
            if any(
                inspect.iscoroutine(task)
                and inspect.getcoroutinestate(task) != inspect.CORO_CREATED
                for task in tasks_to_execute
            ):
                raise
            # Otherwise the progress-bar machinery itself failed (missing
            # package, unsupported environment): run the tasks without tqdm.
            logging.getLogger(__name__).warning(
                "Progress bar unavailable (%r); running tasks without it.", exc
            )

    async def _gather() -> List[Any]:
        return await asyncio.gather(*tasks_to_execute)

    outputs: List[Any] = asyncio.run(_gather())
    return outputs
|
|
|
|
|
|
def chunks(iterable: Iterable, size: int) -> Iterable:
    """Group ``iterable`` into tuples of ``size`` consecutive items.

    The final tuple is padded with ``None`` when the items do not divide
    evenly. Every slot passed to ``zip_longest`` is the same iterator, which
    is what makes consecutive items land in the same tuple.
    """
    shared_iterator = iter(iterable)
    return zip_longest(*(shared_iterator for _ in range(size)), fillvalue=None)
|
|
|
|
|
|
async def batch_gather(
    tasks: List[Coroutine], batch_size: int = 10, verbose: bool = False
) -> List[Any]:
    """Await ``tasks`` in consecutive batches of at most ``batch_size``.

    Each batch is passed to ``asyncio.gather`` and fully awaited before the
    next batch starts. For un-started coroutines, this means at most
    ``batch_size`` of them are running at once.

    Args:
        tasks: Awaitables to run, in order.
        batch_size: Maximum number of tasks gathered together; must be >= 1.
        verbose: If True, print a progress line after each batch.

    Returns:
        The results of all tasks, batch by batch, in the order of ``tasks``.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")

    output: List[Any] = []
    # Slice the list directly instead of using ``chunks``: its final group is
    # padded with ``None`` placeholders, which ``asyncio.gather`` rejects with
    # a TypeError whenever len(tasks) is not a multiple of batch_size.
    for start in range(0, len(tasks), batch_size):
        output_chunk = await asyncio.gather(*tasks[start : start + batch_size])
        output.extend(output_chunk)
        if verbose:
            print(f"Completed {len(output)} out of {len(tasks)} tasks")
    return output
|
|
|
|
|
|
def get_asyncio_module(show_progress: bool = False) -> Any:
    """Pick the object whose ``gather`` is used to await a batch of jobs.

    Returns ``tqdm.asyncio.tqdm_asyncio`` (imported only when needed) if
    ``show_progress`` is true, and the standard ``asyncio`` module otherwise.
    """
    if show_progress:
        from tqdm.asyncio import tqdm_asyncio as progress_module

        return progress_module
    return asyncio
|
|
|
|
|
|
# Default concurrency limit (semaphore size) used by ``run_jobs``.
DEFAULT_NUM_WORKERS: int = 4
|
|
|
|
|
|
async def run_jobs(
    jobs: List[Coroutine],
    show_progress: bool = False,
    workers: int = DEFAULT_NUM_WORKERS,
) -> List[Any]:
    """Run jobs concurrently, with at most ``workers`` awaited at a time.

    Args:
        jobs (List[Coroutine]):
            List of jobs to run.
        show_progress (bool):
            Whether to show progress bar.
        workers (int):
            Maximum number of jobs awaited concurrently; must be >= 1.

    Returns:
        List[Any]:
            List of results.

    Raises:
        ValueError: If ``workers`` is less than 1.
    """
    if workers < 1:
        # asyncio.Semaphore(0) is accepted, but no worker could ever acquire
        # it, so the gather below would wait forever.
        raise ValueError(f"workers must be at least 1, got {workers}")

    asyncio_mod = get_asyncio_module(show_progress=show_progress)
    semaphore = asyncio.Semaphore(workers)

    async def worker(job: Coroutine) -> Any:
        # The semaphore caps how many jobs are being awaited at once.
        async with semaphore:
            return await job

    pool_jobs = [worker(job) for job in jobs]

    return await asyncio_mod.gather(*pool_jobs)
|