# Source: faiss_rag_enterprise/llama_index/async_utils.py
# (111 lines, 2.9 KiB, Python)

"""Async utils."""
import asyncio
from itertools import zip_longest
from typing import Any, Coroutine, Iterable, List
def asyncio_module(show_progress: bool = False) -> Any:
    """Pick the module-like object that provides ``gather``.

    Args:
        show_progress (bool):
            When truthy, ``tqdm.asyncio.tqdm_asyncio`` is imported lazily
            and returned; otherwise the standard ``asyncio`` module is
            returned.

    Returns:
        Any:
            Either ``asyncio`` or ``tqdm_asyncio``.
    """
    if not show_progress:
        return asyncio

    # Imported only on demand so tqdm stays an optional dependency.
    from tqdm.asyncio import tqdm_asyncio

    return tqdm_asyncio
def run_async_tasks(
    tasks: List[Coroutine],
    show_progress: bool = False,
    progress_bar_desc: str = "Running async tasks",
) -> List[Any]:
    """Run a list of async tasks.

    Args:
        tasks (List[Coroutine]):
            Coroutines to run to completion.
        show_progress (bool):
            Whether to try running the tasks through ``tqdm.asyncio`` so a
            progress bar is shown. If that path fails for any reason, the
            failure is logged and the tasks are run without a progress bar.
        progress_bar_desc (str):
            Description passed to the tqdm progress bar.

    Returns:
        List[Any]:
            The results produced by the gather call.
    """
    tasks_to_execute: List[Any] = tasks
    if show_progress:
        try:
            import nest_asyncio
            from tqdm.asyncio import tqdm

            # jupyter notebooks already have an event loop running
            # we need to reuse it instead of creating a new one
            nest_asyncio.apply()
            loop = asyncio.get_event_loop()

            async def _tqdm_gather() -> List[Any]:
                return await tqdm.gather(*tasks_to_execute, desc=progress_bar_desc)

            tqdm_outputs: List[Any] = loop.run_until_complete(_tqdm_gather())
            return tqdm_outputs
        # run the operation w/o tqdm on hitting a fatal
        # may occur in some environments where tqdm.asyncio
        # is not supported
        except Exception as exc:
            # Keep the best-effort fallback, but record why the progress-bar
            # path failed instead of discarding the error silently; otherwise
            # a later failure in the fallback hides the original cause.
            import logging

            logging.getLogger(__name__).warning(
                "Running async tasks with a progress bar failed (%r); "
                "falling back to plain asyncio.",
                exc,
            )

    async def _gather() -> List[Any]:
        return await asyncio.gather(*tasks_to_execute)

    outputs: List[Any] = asyncio.run(_gather())
    return outputs
def chunks(iterable: Iterable, size: int) -> Iterable:
    """Group the items of ``iterable`` into tuples of length ``size``.

    The final tuple is padded with ``None`` when the number of items is not
    a multiple of ``size``.
    """
    shared_iterator = iter(iterable)
    # Every slot pulls from the same iterator, so each output tuple
    # consumes ``size`` consecutive items.
    slots = (shared_iterator for _ in range(size))
    return zip_longest(*slots, fillvalue=None)
async def batch_gather(
    tasks: List[Coroutine], batch_size: int = 10, verbose: bool = False
) -> List[Any]:
    """Await ``tasks`` in consecutive batches of at most ``batch_size``.

    Each batch is awaited with ``asyncio.gather`` before the next one
    starts, and the results of all batches are concatenated.

    Args:
        tasks (List[Coroutine]):
            Coroutines to await.
        batch_size (int):
            Maximum number of coroutines gathered at once.
        verbose (bool):
            Whether to print a progress line after each batch.

    Returns:
        List[Any]:
            Results of every batch, concatenated in batch order.

    Raises:
        ValueError:
            If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    # Materialize once so the input can be sliced and counted.
    pending = list(tasks)
    output: List[Any] = []
    for start in range(0, len(pending), batch_size):
        # Slice instead of padding: a short final batch must contain only
        # real awaitables, because asyncio.gather rejects None entries.
        task_chunk = pending[start : start + batch_size]
        output_chunk = await asyncio.gather(*task_chunk)
        output.extend(output_chunk)
        if verbose:
            print(f"Completed {len(output)} out of {len(pending)} tasks")
    return output
def get_asyncio_module(show_progress: bool = False) -> Any:
    """Return the object whose ``gather`` should be used.

    Args:
        show_progress (bool):
            If truthy, return ``tqdm.asyncio.tqdm_asyncio`` (imported
            lazily); otherwise return the standard ``asyncio`` module.

    Returns:
        Any:
            ``tqdm_asyncio`` or ``asyncio``.
    """
    if show_progress:
        # Lazy import: tqdm is only needed when a progress bar is requested.
        from tqdm.asyncio import tqdm_asyncio

        return tqdm_asyncio
    return asyncio
# Default value of the ``workers`` parameter of ``run_jobs``.
DEFAULT_NUM_WORKERS = 4
async def run_jobs(
    jobs: List[Coroutine],
    show_progress: bool = False,
    workers: int = DEFAULT_NUM_WORKERS,
) -> List[Any]:
    """Run jobs with a cap on how many are awaited at once.

    Args:
        jobs (List[Coroutine]):
            List of jobs to run.
        show_progress (bool):
            Whether to show progress bar.
        workers (int):
            Initial value of the semaphore that each job must acquire
            before it is awaited.

    Returns:
        List[Any]:
            List of results.
    """
    gather_provider = get_asyncio_module(show_progress=show_progress)
    limiter = asyncio.Semaphore(workers)

    async def _run_limited(coro: Coroutine) -> Any:
        # Hold a semaphore slot for the duration of the job.
        async with limiter:
            result = await coro
        return result

    return await gather_provider.gather(*map(_run_limited, jobs))