"""Async utils.""" import asyncio from itertools import zip_longest from typing import Any, Coroutine, Iterable, List def asyncio_module(show_progress: bool = False) -> Any: if show_progress: from tqdm.asyncio import tqdm_asyncio module = tqdm_asyncio else: module = asyncio return module def run_async_tasks( tasks: List[Coroutine], show_progress: bool = False, progress_bar_desc: str = "Running async tasks", ) -> List[Any]: """Run a list of async tasks.""" tasks_to_execute: List[Any] = tasks if show_progress: try: import nest_asyncio from tqdm.asyncio import tqdm # jupyter notebooks already have an event loop running # we need to reuse it instead of creating a new one nest_asyncio.apply() loop = asyncio.get_event_loop() async def _tqdm_gather() -> List[Any]: return await tqdm.gather(*tasks_to_execute, desc=progress_bar_desc) tqdm_outputs: List[Any] = loop.run_until_complete(_tqdm_gather()) return tqdm_outputs # run the operation w/o tqdm on hitting a fatal # may occur in some environments where tqdm.asyncio # is not supported except Exception: pass async def _gather() -> List[Any]: return await asyncio.gather(*tasks_to_execute) outputs: List[Any] = asyncio.run(_gather()) return outputs def chunks(iterable: Iterable, size: int) -> Iterable: args = [iter(iterable)] * size return zip_longest(*args, fillvalue=None) async def batch_gather( tasks: List[Coroutine], batch_size: int = 10, verbose: bool = False ) -> List[Any]: output: List[Any] = [] for task_chunk in chunks(tasks, batch_size): output_chunk = await asyncio.gather(*task_chunk) output.extend(output_chunk) if verbose: print(f"Completed {len(output)} out of {len(tasks)} tasks") return output def get_asyncio_module(show_progress: bool = False) -> Any: if show_progress: from tqdm.asyncio import tqdm_asyncio module = tqdm_asyncio else: module = asyncio return module DEFAULT_NUM_WORKERS = 4 async def run_jobs( jobs: List[Coroutine], show_progress: bool = False, workers: int = DEFAULT_NUM_WORKERS, ) -> List[Any]: """Run jobs. Args: jobs (List[Coroutine]): List of jobs to run. show_progress (bool): Whether to show progress bar. Returns: List[Any]: List of results. """ asyncio_mod = get_asyncio_module(show_progress=show_progress) semaphore = asyncio.Semaphore(workers) async def worker(job: Coroutine) -> Any: async with semaphore: return await job pool_jobs = [worker(job) for job in jobs] return await asyncio_mod.gather(*pool_jobs)