"""Download."""
|
|
|
|
import json
|
|
import os
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional, Union
|
|
|
|
import requests
|
|
import tqdm
|
|
|
|
from llama_index.download.module import LLAMA_HUB_URL
|
|
from llama_index.download.utils import (
|
|
get_file_content,
|
|
get_file_content_bytes,
|
|
initialize_directory,
|
|
)

LLAMA_DATASETS_LFS_URL = (
    "https://media.githubusercontent.com/media/run-llama/llama-datasets/main"
)

LLAMA_DATASETS_SOURCE_FILES_GITHUB_TREE_URL = (
    "https://github.com/run-llama/llama-datasets/tree/main"
)
LLAMA_SOURCE_FILES_PATH = "source_files"

# Maps a dataset class name to the JSON file that stores the dataset itself.
# Both the British ("Labelled") and American ("Labeled") spellings are
# registered so that either resolves to the same file.
DATASET_CLASS_FILENAME_REGISTRY = {
    "LabelledRagDataset": "rag_dataset.json",
    "LabeledRagDataset": "rag_dataset.json",
    "LabelledPairwiseEvaluatorDataset": "pairwise_evaluator_dataset.json",
    "LabeledPairwiseEvaluatorDataset": "pairwise_evaluator_dataset.json",
    "LabelledEvaluatorDataset": "evaluator_dataset.json",
    "LabeledEvaluatorDataset": "evaluator_dataset.json",
}

PATH_TYPE = Union[str, Path]


def _resolve_dataset_file_name(class_name: str) -> str:
    """Resolve filename based on dataset class."""
    try:
        return DATASET_CLASS_FILENAME_REGISTRY[class_name]
    except KeyError as err:
        raise ValueError(f"Invalid dataset class name: {class_name}") from err
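

# A minimal, runnable illustration of the registry lookup above: both
# spellings of a dataset class resolve to the same file name. Defined for
# reference only and never called by this module.
def _example_resolve_dataset_file_name() -> None:
    assert _resolve_dataset_file_name("LabelledRagDataset") == "rag_dataset.json"
    assert _resolve_dataset_file_name("LabeledRagDataset") == "rag_dataset.json"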


def _get_source_files_list(source_tree_url: str, path: str) -> List[str]:
    """Get the list of source files to download."""
    # GitHub serves the tree view as JSON with an embedded "payload" object
    # describing the directory listing; extract the file names from it.
    resp = requests.get(source_tree_url + path + "?recursive=1")
    payload = resp.json()["payload"]
    return [item["name"] for item in payload["tree"]["items"]]
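

# A sketch of how the listing helper is called (it mirrors `get_dataset_info`
# below). The dataset id "PaulGrahamEssayDataset" is illustrative. Defined for
# reference only and never called, since it performs a network request.
def _example_get_source_files_list() -> List[str]:
    return _get_source_files_list(
        LLAMA_DATASETS_SOURCE_FILES_GITHUB_TREE_URL,
        "/PaulGrahamEssayDataset/source_files",
    )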


def get_dataset_info(
    local_dir_path: PATH_TYPE,
    remote_dir_path: PATH_TYPE,
    remote_source_dir_path: PATH_TYPE,
    dataset_class: str,
    refresh_cache: bool = False,
    library_path: str = "library.json",
    source_files_path: str = "source_files",
    disable_library_cache: bool = False,
) -> Dict:
    """Get dataset info."""
    if isinstance(local_dir_path, str):
        local_dir_path = Path(local_dir_path)

    local_library_path = f"{local_dir_path}/{library_path}"
    dataset_id = None
    # Fall back to the user-supplied class name; it is overwritten below with
    # the canonical `className` from the data card whenever the card is
    # fetched. Without this fallback, a cache hit would leave the name unset.
    dataset_class_name = dataset_class
    source_files = []

    # Check cache first
    if not refresh_cache and os.path.exists(local_library_path):
        with open(local_library_path) as f:
            library = json.load(f)
        if dataset_class in library:
            dataset_id = library[dataset_class]["id"]
            source_files = library[dataset_class].get("source_files", [])

    # Fetch up-to-date library from remote repo if dataset_id not found
    if dataset_id is None:
        library_raw_content, _ = get_file_content(
            str(remote_dir_path), f"/{library_path}"
        )
        library = json.loads(library_raw_content)
        if dataset_class not in library:
            raise ValueError("Dataset class name not found in library")

        dataset_id = library[dataset_class]["id"]

        # get data card
        raw_card_content, _ = get_file_content(
            str(remote_dir_path), f"/{dataset_id}/card.json"
        )
        card = json.loads(raw_card_content)
        dataset_class_name = card["className"]

        # Only RAG datasets ship separate source documents.
        source_files = []
        if dataset_class_name == "LabelledRagDataset":
            source_files = _get_source_files_list(
                str(remote_source_dir_path), f"/{dataset_id}/{source_files_path}"
            )

        # create cache dir if needed
        local_library_dir = os.path.dirname(local_library_path)
        if not disable_library_cache:
            if not os.path.exists(local_library_dir):
                os.makedirs(local_library_dir)

            # Update cache
            with open(local_library_path, "w") as f:
                f.write(library_raw_content)

    if dataset_id is None:
        raise ValueError("Dataset class name not found in library")

    return {
        "dataset_id": dataset_id,
        "dataset_class_name": dataset_class_name,
        "source_files": source_files,
    }
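

# A sketch of a direct `get_dataset_info` call, using the same remote
# locations that `download_llama_dataset` below passes in. The dataset class
# name and cache directory are illustrative. Defined for reference only and
# never called, since it performs network requests and writes to the cache.
def _example_get_dataset_info() -> Dict:
    return get_dataset_info(
        local_dir_path="./llama_datasets_cache",
        remote_dir_path=LLAMA_HUB_URL,
        remote_source_dir_path=LLAMA_DATASETS_SOURCE_FILES_GITHUB_TREE_URL,
        dataset_class="PaulGrahamEssayDataset",
        library_path="llama_datasets/library.json",
    )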


def download_dataset_and_source_files(
    local_dir_path: PATH_TYPE,
    remote_lfs_dir_path: PATH_TYPE,
    source_files_dir_path: PATH_TYPE,
    dataset_id: str,
    dataset_class_name: str,
    source_files: List[str],
    refresh_cache: bool = False,
    base_file_name: str = "rag_dataset.json",
    override_path: bool = False,
    show_progress: bool = False,
) -> None:
    """Download dataset and source files."""
    if isinstance(local_dir_path, str):
        local_dir_path = Path(local_dir_path)

    if override_path:
        module_path = str(local_dir_path)
    else:
        module_path = f"{local_dir_path}/{dataset_id}"

    if refresh_cache or not os.path.exists(module_path):
        os.makedirs(module_path, exist_ok=True)

        # NOTE: `base_file_name` is always recomputed from the dataset class
        # name, so the default argument above is never used here.
        base_file_name = _resolve_dataset_file_name(dataset_class_name)

        dataset_raw_content, _ = get_file_content(
            str(remote_lfs_dir_path), f"/{dataset_id}/{base_file_name}"
        )

        with open(f"{module_path}/{base_file_name}", "w") as f:
            f.write(dataset_raw_content)

        # Get content of source files; only RAG datasets have them.
        if dataset_class_name == "LabelledRagDataset":
            os.makedirs(f"{module_path}/{source_files_dir_path}", exist_ok=True)
            source_files_iterator = (
                tqdm.tqdm(source_files) if show_progress else source_files
            )
            for source_file in source_files_iterator:
                # PDFs must be fetched and written as binary; everything else
                # is treated as text.
                if source_file.endswith(".pdf"):
                    source_file_raw_content_bytes, _ = get_file_content_bytes(
                        str(remote_lfs_dir_path),
                        f"/{dataset_id}/{source_files_dir_path}/{source_file}",
                    )
                    with open(
                        f"{module_path}/{source_files_dir_path}/{source_file}", "wb"
                    ) as f:
                        f.write(source_file_raw_content_bytes)
                else:
                    source_file_raw_content, _ = get_file_content(
                        str(remote_lfs_dir_path),
                        f"/{dataset_id}/{source_files_dir_path}/{source_file}",
                    )
                    with open(
                        f"{module_path}/{source_files_dir_path}/{source_file}", "w"
                    ) as f:
                        f.write(source_file_raw_content)
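

# A sketch of a direct call, using values that `get_dataset_info` would
# normally supply; the dataset id and source file list are illustrative.
# Defined for reference only and never called, since it downloads files.
def _example_download_dataset_and_source_files() -> None:
    download_dataset_and_source_files(
        local_dir_path="./data",
        remote_lfs_dir_path=LLAMA_DATASETS_LFS_URL,
        source_files_dir_path=LLAMA_SOURCE_FILES_PATH,
        dataset_id="PaulGrahamEssayDataset",
        dataset_class_name="LabelledRagDataset",
        source_files=["paul_graham_essay.txt"],
        show_progress=True,
    )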


def download_llama_dataset(
    dataset_class: str,
    llama_hub_url: str = LLAMA_HUB_URL,
    llama_datasets_lfs_url: str = LLAMA_DATASETS_LFS_URL,
    llama_datasets_source_files_tree_url: str = LLAMA_DATASETS_SOURCE_FILES_GITHUB_TREE_URL,
    refresh_cache: bool = False,
    custom_dir: Optional[str] = None,
    custom_path: Optional[str] = None,
    source_files_dirpath: str = LLAMA_SOURCE_FILES_PATH,
    library_path: str = "llama_datasets/library.json",
    disable_library_cache: bool = False,
    override_path: bool = False,
    show_progress: bool = False,
) -> Any:
    """
    Download a LlamaDataset from LlamaHub.

    Args:
        dataset_class: The name of the LlamaDataset class you want to download,
            such as `PaulGrahamEssayDataset`.
        llama_hub_url: URL of the repo that hosts the library file.
        llama_datasets_lfs_url: URL of the Git LFS media endpoint for the
            llama-datasets repo.
        llama_datasets_source_files_tree_url: URL of the GitHub tree used to
            list a dataset's source files.
        refresh_cache: If true, the local cache will be skipped and the
            dataset will be fetched directly from the remote repo.
        custom_dir: Custom dir name to download dataset into (under parent folder).
        custom_path: Custom dirpath to download dataset into.
        source_files_dirpath: Name of the subdirectory that holds source files.
        library_path: File name of the library file.
        disable_library_cache: If true, do not write the library file to the
            local cache.
        override_path: If true, download directly into the given path rather
            than a `<dataset_id>` subdirectory.
        show_progress: If true, show a progress bar while downloading source files.

    Returns:
        A tuple of the path to the dataset JSON file and the path to the
        source files directory.
    """
    # create directory / get path
    dirpath = initialize_directory(custom_path=custom_path, custom_dir=custom_dir)

    # fetch info from library.json file
    dataset_info = get_dataset_info(
        local_dir_path=dirpath,
        remote_dir_path=llama_hub_url,
        remote_source_dir_path=llama_datasets_source_files_tree_url,
        dataset_class=dataset_class,
        refresh_cache=refresh_cache,
        library_path=library_path,
        disable_library_cache=disable_library_cache,
    )
    dataset_id = dataset_info["dataset_id"]
    source_files = dataset_info["source_files"]
    dataset_class_name = dataset_info["dataset_class_name"]

    dataset_filename = _resolve_dataset_file_name(dataset_class_name)

    download_dataset_and_source_files(
        local_dir_path=dirpath,
        remote_lfs_dir_path=llama_datasets_lfs_url,
        source_files_dir_path=source_files_dirpath,
        dataset_id=dataset_id,
        dataset_class_name=dataset_class_name,
        source_files=source_files,
        refresh_cache=refresh_cache,
        override_path=override_path,
        show_progress=show_progress,
    )

    if override_path:
        module_path = str(dirpath)
    else:
        module_path = f"{dirpath}/{dataset_id}"

    return (
        f"{module_path}/{dataset_filename}",
        f"{module_path}/{source_files_dirpath}",
    )
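

# A sketch of the typical entry point, with "PaulGrahamEssayDataset" as an
# illustrative dataset class. Guarded so that importing this module stays
# side-effect free.
if __name__ == "__main__":
    rag_dataset_path, source_files_dir = download_llama_dataset(
        "PaulGrahamEssayDataset",
        custom_path="./data",
        show_progress=True,
    )
    print(rag_dataset_path)
    print(source_files_dir)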