282 lines
8.9 KiB
Python
282 lines
8.9 KiB
Python
"""Param tuner."""
|
|
|
|
|
|
import asyncio
|
|
from abc import abstractmethod
|
|
from copy import deepcopy
|
|
from typing import Any, Awaitable, Callable, Dict, List, Optional
|
|
|
|
from llama_index.bridge.pydantic import BaseModel, Field, PrivateAttr
|
|
from llama_index.utils import get_tqdm_iterable
|
|
|
|
|
|
class RunResult(BaseModel):
    """Outcome of one evaluation of a parameter combination."""

    # Numeric score of the run; the tuners in this module rank runs by it,
    # highest first.
    score: float
    # Parameters associated with the run (filled in by whoever builds the
    # result).
    params: Dict[str, Any]
    # Free-form extra information; defaults to a fresh empty dict per instance.
    metadata: Dict[str, Any] = Field(default_factory=dict, description="Metadata.")
|
|
|
|
|
|
class TunedResult(BaseModel):
    """Collection of run results together with the index of the best one."""

    # Every run produced during tuning.
    run_results: List[RunResult]
    # Position inside ``run_results`` of the best run.
    best_idx: int

    @property
    def best_run_result(self) -> RunResult:
        """Return the entry of ``run_results`` located at ``best_idx``."""
        best_position = self.best_idx
        return self.run_results[best_position]
|
|
|
|
|
|
def generate_param_combinations(param_dict: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Generate every combination of the candidate parameter values.

    Args:
        param_dict: Maps each parameter name to an iterable of candidate
            values for that parameter.

    Returns:
        One dict per combination, mapping every parameter name to a single
        chosen value. The last key of ``param_dict`` varies slowest and the
        first key varies fastest. An empty ``param_dict`` yields ``[{}]``;
        a parameter without candidates yields ``[]``. Each returned dict is
        a deep copy, so mutating it never touches ``param_dict``.
    """
    # Expand one parameter at a time instead of recursing: this avoids
    # deep-copying the remaining parameters at every recursion node and
    # removes the shared mutable accumulator.
    partial_combinations: List[Dict[str, Any]] = [{}]

    # Walk the keys from last to first so ordering (both of the result list
    # and of the keys inside each dict) matches a pop-from-the-end expansion.
    for param_name in reversed(list(param_dict)):
        if not partial_combinations:
            # An earlier parameter had no candidates: nothing can be built.
            break
        # Materialize once so one-shot iterables can be reused per partial.
        param_vals = list(param_dict[param_name])
        partial_combinations = [
            {**partial, param_name: param_val}
            for partial in partial_combinations
            for param_val in param_vals
        ]

    # Copy at the very end so no returned dict shares mutable values with
    # the input or with another combination.
    return [deepcopy(combination) for combination in partial_combinations]
|
|
|
|
|
|
class BaseParamTuner(BaseModel):
    """Common fields and entry points shared by the parameter tuners."""

    # Parameters to search over.
    param_dict: Dict[str, Any] = Field(
        ..., description="A dictionary of parameters to iterate over."
    )
    # Parameters handed unchanged to every job.
    fixed_param_dict: Dict[str, Any] = Field(
        default_factory=dict,
        description="A dictionary of fixed parameters passed to each job.",
    )
    # Whether tuners should display progress while running.
    show_progress: bool = False

    @abstractmethod
    def tune(self) -> TunedResult:
        """Tune parameters and return the collected results."""

    async def atune(self) -> TunedResult:
        """Tune parameters from async code.

        This default simply calls the synchronous ``tune`` and returns its
        result, so it does not yield while tuning runs. Override it if you
        implement a native async method.

        """
        tuned = self.tune()
        return tuned
|
|
|
|
|
|
class ParamTuner(BaseParamTuner):
    """Parameter tuner that evaluates combinations one after another.

    Args:
        param_dict(Dict): A dictionary of parameters to iterate over.
            Example param_dict:
            {
                "num_epochs": [10, 20],
                "batch_size": [8, 16, 32],
            }
        fixed_param_dict(Dict): A dictionary of fixed parameters passed to each job.
        param_fn (Callable): Function called with the merged parameter dict
            of each combination; it returns a RunResult.

    """

    param_fn: Callable[[Dict[str, Any]], RunResult] = Field(
        ..., description="Function to run with parameters."
    )

    def tune(self) -> TunedResult:
        """Run ``param_fn`` on every combination and rank the results.

        Returns:
            A TunedResult whose ``run_results`` are ordered by descending
            score, with ``best_idx`` set to 0.
        """
        # Every key of param_dict names a parameter and maps to the values
        # to try; expand them into the full list of combinations.
        candidates = generate_param_combinations(self.param_dict)
        progress_iter = get_tqdm_iterable(
            candidates, self.show_progress, "Param combinations."
        )

        # Tuned values are unpacked last, so they override fixed parameters
        # that use the same key.
        outcomes = [
            self.param_fn({**self.fixed_param_dict, **candidate})
            for candidate in progress_iter
        ]

        # Highest score first; ties keep their evaluation order.
        outcomes.sort(key=lambda run: run.score, reverse=True)

        return TunedResult(run_results=outcomes, best_idx=0)
|
|
|
|
|
|
class AsyncParamTuner(BaseParamTuner):
    """Async Parameter tuner.

    Runs ``aparam_fn`` concurrently over every parameter combination, with
    at most ``num_workers`` jobs in flight per ``atune`` call.

    Args:
        param_dict(Dict): A dictionary of parameters to iterate over.
            Example param_dict:
            {
                "num_epochs": [10, 20],
                "batch_size": [8, 16, 32],
            }
        fixed_param_dict(Dict): A dictionary of fixed parameters passed to each job.
        aparam_fn (Callable): An async function to run with parameters.
        num_workers (int): Number of workers to use.

    """

    aparam_fn: Callable[[Dict[str, Any]], Awaitable[RunResult]] = Field(
        ..., description="Async function to run with parameters."
    )
    num_workers: int = Field(2, description="Number of workers to use.")

    # Semaphore used by the most recent run; replaced on every ``atune`` call.
    _semaphore: asyncio.Semaphore = PrivateAttr()

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Init params."""
        super().__init__(*args, **kwargs)
        # Kept so the attribute exists right after construction; ``atune``
        # swaps in a fresh semaphore before any job runs.
        self._semaphore = asyncio.Semaphore(self.num_workers)

    async def atune(self) -> TunedResult:
        """Run tuning concurrently and rank the results.

        Returns:
            A TunedResult whose ``run_results`` are ordered by descending
            score, with ``best_idx`` set to 0.
        """
        # asyncio synchronization primitives belong to one event loop, while
        # ``tune`` starts a new loop through ``asyncio.run`` on every call.
        # Reusing a semaphore created earlier can therefore raise
        # RuntimeError as soon as jobs have to wait on it, so build one for
        # this run inside the loop that is actually executing.
        semaphore = asyncio.Semaphore(self.num_workers)
        self._semaphore = semaphore

        async def aparam_fn_worker(full_param_dict: Dict[str, Any]) -> RunResult:
            """Run one job while holding a semaphore slot."""
            async with semaphore:
                return await self.aparam_fn(full_param_dict)

        # Each key in param_dict is a parameter to tune and each value is
        # the list of values to try. Tuned values are unpacked last, so they
        # override fixed parameters that use the same key.
        run_jobs = [
            aparam_fn_worker({**self.fixed_param_dict, **param_combination})
            for param_combination in generate_param_combinations(self.param_dict)
        ]

        if self.show_progress:
            # Imported lazily so tqdm is only needed when progress is shown.
            from tqdm.asyncio import tqdm_asyncio

            all_run_results = await tqdm_asyncio.gather(*run_jobs)
        else:
            all_run_results = await asyncio.gather(*run_jobs)

        # Sort the results by score, highest first.
        sorted_run_results = sorted(
            all_run_results, key=lambda x: x.score, reverse=True
        )

        return TunedResult(run_results=sorted_run_results, best_idx=0)

    def tune(self) -> TunedResult:
        """Run tuning synchronously by driving ``atune`` with ``asyncio.run``.

        ``asyncio.run`` cannot be called while another event loop is running
        in the same thread; use ``await atune()`` from async code instead.
        """
        return asyncio.run(self.atune())
|
|
|
|
|
|
class RayTuneParamTuner(BaseParamTuner):
    """Parameter tuner powered by Ray Tune.

    Args:
        param_dict(Dict): A dictionary of parameters to iterate over.
            Example param_dict:
            {
                "num_epochs": [10, 20],
                "batch_size": [8, 16, 32],
            }
        fixed_param_dict(Dict): A dictionary of fixed parameters passed to each job.
        param_fn (Callable): Function called with the merged parameter dict
            of each trial; it returns a RunResult.
        run_config_dict (Optional[dict]): Keyword arguments used to build the
            Ray ``RunConfig``; no run config is passed when empty or None.

    """

    param_fn: Callable[[Dict[str, Any]], RunResult] = Field(
        ..., description="Function to run with parameters."
    )

    run_config_dict: Optional[dict] = Field(
        default=None, description="Run config dict for Ray Tune."
    )

    def tune(self) -> TunedResult:
        """Run a Ray Tune grid search and rank the results.

        Returns:
            A TunedResult whose ``run_results`` are ordered by descending
            score, with ``best_idx`` set to 0.
        """
        # Imported lazily so Ray is only required when this tuner is used.
        from ray import tune
        from ray.train import RunConfig

        # Every list of candidate values becomes a tune.grid_search entry.
        ray_param_dict = {
            param_name: tune.grid_search(param_vals)
            for param_name, param_vals in self.param_dict.items()
        }

        def param_fn_wrapper(
            ray_param_dict: Dict, fixed_param_dict: Optional[Dict] = None
        ) -> Dict:
            # Wrapper that merges the fixed parameters with the ones Ray is
            # tuning; tuned values override fixed ones on key clashes.
            merged_params = {**(fixed_param_dict or {}), **ray_param_dict}
            # Ray Tune's API expects a dict, so convert the RunResult.
            return self.param_fn(merged_params).dict()

        def _to_run_result(result: Any) -> RunResult:
            # Rebuild a RunResult from the reported metrics, then attach the
            # timestamp taken from those metrics as extra metadata.
            run_result = RunResult.parse_obj(result.metrics)
            if result.metrics:
                run_result.metadata["timestamp"] = result.metrics["timestamp"]
            else:
                run_result.metadata["timestamp"] = None
            return run_result

        run_config = None
        if self.run_config_dict:
            run_config = RunConfig(**self.run_config_dict)

        trainable = tune.with_parameters(
            param_fn_wrapper, fixed_param_dict=self.fixed_param_dict
        )
        tuner = tune.Tuner(
            trainable,
            param_space=ray_param_dict,
            run_config=run_config,
        )

        results = tuner.fit()
        collected = [_to_run_result(results[idx]) for idx in range(len(results))]

        # Highest score first; ties keep the order in which Ray listed them.
        collected.sort(key=lambda run: run.score, reverse=True)

        return TunedResult(run_results=collected, best_idx=0)
|