"""Download.""" import json import os from pathlib import Path from typing import Any, Dict, List, Optional, Union import requests import tqdm from llama_index.download.module import LLAMA_HUB_URL from llama_index.download.utils import ( get_file_content, get_file_content_bytes, initialize_directory, ) LLAMA_DATASETS_LFS_URL = ( f"https://media.githubusercontent.com/media/run-llama/llama-datasets/main" ) LLAMA_DATASETS_SOURCE_FILES_GITHUB_TREE_URL = ( "https://github.com/run-llama/llama-datasets/tree/main" ) LLAMA_SOURCE_FILES_PATH = "source_files" DATASET_CLASS_FILENAME_REGISTRY = { "LabelledRagDataset": "rag_dataset.json", "LabeledRagDataset": "rag_dataset.json", "LabelledPairwiseEvaluatorDataset": "pairwise_evaluator_dataset.json", "LabeledPairwiseEvaluatorDataset": "pairwise_evaluator_dataset.json", "LabelledEvaluatorDataset": "evaluator_dataset.json", "LabeledEvaluatorDataset": "evaluator_dataset.json", } PATH_TYPE = Union[str, Path] def _resolve_dataset_file_name(class_name: str) -> str: """Resolve filename based on dataset class.""" try: return DATASET_CLASS_FILENAME_REGISTRY[class_name] except KeyError as err: raise ValueError("Invalid dataset filename.") from err def _get_source_files_list(source_tree_url: str, path: str) -> List[str]: """Get the list of source files to download.""" resp = requests.get(source_tree_url + path + "?recursive=1") payload = resp.json()["payload"] return [item["name"] for item in payload["tree"]["items"]] def get_dataset_info( local_dir_path: PATH_TYPE, remote_dir_path: PATH_TYPE, remote_source_dir_path: PATH_TYPE, dataset_class: str, refresh_cache: bool = False, library_path: str = "library.json", source_files_path: str = "source_files", disable_library_cache: bool = False, ) -> Dict: """Get dataset info.""" if isinstance(local_dir_path, str): local_dir_path = Path(local_dir_path) local_library_path = f"{local_dir_path}/{library_path}" dataset_id = None source_files = [] # Check cache first if not refresh_cache and os.path.exists(local_library_path): with open(local_library_path) as f: library = json.load(f) if dataset_class in library: dataset_id = library[dataset_class]["id"] source_files = library[dataset_class].get("source_files", []) # Fetch up-to-date library from remote repo if dataset_id not found if dataset_id is None: library_raw_content, _ = get_file_content( str(remote_dir_path), f"/{library_path}" ) library = json.loads(library_raw_content) if dataset_class not in library: raise ValueError("Loader class name not found in library") dataset_id = library[dataset_class]["id"] # get data card raw_card_content, _ = get_file_content( str(remote_dir_path), f"/{dataset_id}/card.json" ) card = json.loads(raw_card_content) dataset_class_name = card["className"] source_files = [] if dataset_class_name == "LabelledRagDataset": source_files = _get_source_files_list( str(remote_source_dir_path), f"/{dataset_id}/{source_files_path}" ) # create cache dir if needed local_library_dir = os.path.dirname(local_library_path) if not disable_library_cache: if not os.path.exists(local_library_dir): os.makedirs(local_library_dir) # Update cache with open(local_library_path, "w") as f: f.write(library_raw_content) if dataset_id is None: raise ValueError("Dataset class name not found in library") return { "dataset_id": dataset_id, "dataset_class_name": dataset_class_name, "source_files": source_files, } def download_dataset_and_source_files( local_dir_path: PATH_TYPE, remote_lfs_dir_path: PATH_TYPE, source_files_dir_path: PATH_TYPE, dataset_id: str, 
def download_dataset_and_source_files(
    local_dir_path: PATH_TYPE,
    remote_lfs_dir_path: PATH_TYPE,
    source_files_dir_path: PATH_TYPE,
    dataset_id: str,
    dataset_class_name: str,
    source_files: List[str],
    refresh_cache: bool = False,
    base_file_name: str = "rag_dataset.json",
    override_path: bool = False,
    show_progress: bool = False,
) -> None:
    """Download dataset and source files."""
    if isinstance(local_dir_path, str):
        local_dir_path = Path(local_dir_path)

    if override_path:
        module_path = str(local_dir_path)
    else:
        module_path = f"{local_dir_path}/{dataset_id}"

    if refresh_cache or not os.path.exists(module_path):
        os.makedirs(module_path, exist_ok=True)

        # The dataset file name is resolved from the class name, which
        # takes precedence over the `base_file_name` default.
        base_file_name = _resolve_dataset_file_name(dataset_class_name)

        dataset_raw_content, _ = get_file_content(
            str(remote_lfs_dir_path), f"/{dataset_id}/{base_file_name}"
        )

        with open(f"{module_path}/{base_file_name}", "w") as f:
            f.write(dataset_raw_content)

        # Download the source documents that accompany RAG datasets; both
        # registered spellings are accepted.
        if dataset_class_name in ("LabelledRagDataset", "LabeledRagDataset"):
            os.makedirs(f"{module_path}/{source_files_dir_path}", exist_ok=True)
            if show_progress:
                source_files_iterator = tqdm.tqdm(source_files)
            else:
                source_files_iterator = source_files
            for source_file in source_files_iterator:
                # PDFs are binary and must be fetched as bytes; everything
                # else is treated as text.
                if source_file.endswith(".pdf"):
                    source_file_raw_content_bytes, _ = get_file_content_bytes(
                        str(remote_lfs_dir_path),
                        f"/{dataset_id}/{source_files_dir_path}/{source_file}",
                    )
                    with open(
                        f"{module_path}/{source_files_dir_path}/{source_file}", "wb"
                    ) as f:
                        f.write(source_file_raw_content_bytes)
                else:
                    source_file_raw_content, _ = get_file_content(
                        str(remote_lfs_dir_path),
                        f"/{dataset_id}/{source_files_dir_path}/{source_file}",
                    )
                    with open(
                        f"{module_path}/{source_files_dir_path}/{source_file}", "w"
                    ) as f:
                        f.write(source_file_raw_content)
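
# Resulting on-disk layout for a RAG dataset (illustrative; assumes
# override_path=False and a hypothetical dataset id "paul_graham_essay"):
#
#     <local_dir_path>/paul_graham_essay/
#         rag_dataset.json
#         source_files/
#             paul_graham_essay.txt
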
def download_llama_dataset(
    dataset_class: str,
    llama_hub_url: str = LLAMA_HUB_URL,
    llama_datasets_lfs_url: str = LLAMA_DATASETS_LFS_URL,
    llama_datasets_source_files_tree_url: str = LLAMA_DATASETS_SOURCE_FILES_GITHUB_TREE_URL,
    refresh_cache: bool = False,
    custom_dir: Optional[str] = None,
    custom_path: Optional[str] = None,
    source_files_dirpath: str = LLAMA_SOURCE_FILES_PATH,
    library_path: str = "llama_datasets/library.json",
    disable_library_cache: bool = False,
    override_path: bool = False,
    show_progress: bool = False,
) -> Any:
    """
    Download a llama-dataset from the llama-datasets repository.

    Args:
        dataset_class: The name of the llama-dataset you want to download,
            such as `PaulGrahamEssayDataset`.
        llama_hub_url: URL of the llama-hub repo holding the library file.
        llama_datasets_lfs_url: URL of the llama-datasets LFS media endpoint
            from which the dataset and source files are fetched.
        llama_datasets_source_files_tree_url: URL of the llama-datasets
            GitHub tree, used to list a dataset's source files.
        refresh_cache: If true, the local cache will be skipped and the
            dataset will be fetched directly from the remote repo.
        custom_dir: Custom dir name to download the dataset into (under
            parent folder).
        custom_path: Custom dirpath to download the dataset into.
        source_files_dirpath: Name of the subdirectory holding source files.
        library_path: Path of the library file within the remote repo.
        disable_library_cache: If true, do not write the local library cache.
        override_path: If true, download into the target directory itself
            rather than a subdirectory named after the dataset id.
        show_progress: If true, show a progress bar while downloading
            source files.

    Returns:
        A tuple of paths: the downloaded dataset's JSON file and its
        source-files directory.
    """
    # create directory / get path
    dirpath = initialize_directory(custom_path=custom_path, custom_dir=custom_dir)

    # fetch info from library.json file
    dataset_info = get_dataset_info(
        local_dir_path=dirpath,
        remote_dir_path=llama_hub_url,
        remote_source_dir_path=llama_datasets_source_files_tree_url,
        dataset_class=dataset_class,
        refresh_cache=refresh_cache,
        library_path=library_path,
        disable_library_cache=disable_library_cache,
    )
    dataset_id = dataset_info["dataset_id"]
    source_files = dataset_info["source_files"]
    dataset_class_name = dataset_info["dataset_class_name"]
    dataset_filename = _resolve_dataset_file_name(dataset_class_name)

    download_dataset_and_source_files(
        local_dir_path=dirpath,
        remote_lfs_dir_path=llama_datasets_lfs_url,
        source_files_dir_path=source_files_dirpath,
        dataset_id=dataset_id,
        dataset_class_name=dataset_class_name,
        source_files=source_files,
        refresh_cache=refresh_cache,
        override_path=override_path,
        show_progress=show_progress,
    )

    if override_path:
        module_path = str(dirpath)
    else:
        module_path = f"{dirpath}/{dataset_id}"

    return (
        f"{module_path}/{dataset_filename}",
        f"{module_path}/{LLAMA_SOURCE_FILES_PATH}",
    )
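

if __name__ == "__main__":
    # Minimal usage sketch (requires network access; "PaulGrahamEssayDataset"
    # is a hypothetical entry that must exist in the remote library.json):
    rag_dataset_path, source_files_dir = download_llama_dataset(
        "PaulGrahamEssayDataset",
        custom_path="./data/paul_graham",
        show_progress=True,
    )
    print(rag_dataset_path)  # e.g. ./data/paul_graham/<id>/rag_dataset.json
    print(source_files_dir)  # e.g. ./data/paul_graham/<id>/source_files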