import copy
import json
import os
from dataclasses import asdict, dataclass, field
from typing import List, Union


@dataclass
class DatasetInfo:
    """Metadata for a single dataset inside a collection."""
    name: str
    weight: float = 1.0  # sampling weight within its parent collection
    task_type: str = ''
    tags: List[str] = field(default_factory=list)
    args: dict = field(default_factory=dict)
    hierarchy: List[str] = field(default_factory=list)

    def get_data(self) -> dict:
        # Local import to avoid loading the benchmark registry at module import time.
        from evalscope.benchmarks import Benchmark

        benchmark_meta = Benchmark.get(self.name)

        data_adapter = benchmark_meta.get_data_adapter(config=self.args)
        data_dict = data_adapter.load()
        prompts = data_adapter.gen_prompts(data_dict)
        return prompts


def flatten_weight(collection: 'CollectionSchema', base_weight=1):
    """Recursively normalize weights so that the leaf dataset weights sum to ``base_weight``."""
    total_weight = sum(dataset.weight for dataset in collection.datasets)
    for dataset in collection.datasets:
        current_weight = dataset.weight / total_weight * base_weight
        if isinstance(dataset, CollectionSchema):
            flatten_weight(dataset, current_weight)
        else:
            dataset.weight = current_weight


def flatten_name(collection: 'CollectionSchema', parent_names=None):
    if parent_names is None:
        parent_names = []
    current_names = parent_names + [collection.name]
    for dataset in collection.datasets:
        if isinstance(dataset, CollectionSchema):
            flatten_name(dataset, current_names)
        else:
            dataset.hierarchy = current_names.copy()


def flatten_datasets(collection: 'CollectionSchema') -> List[DatasetInfo]:
    flat_datasets = []
    for dataset in collection.datasets:
        if isinstance(dataset, CollectionSchema):
            flat_datasets.extend(flatten_datasets(dataset))
        else:
            flat_datasets.append(dataset)
    return flat_datasets


@dataclass
class CollectionSchema:
    """A named, weighted group of datasets and/or nested sub-collections."""
    name: str
    weight: float = 1.0
    datasets: List[Union[DatasetInfo, 'CollectionSchema']] = field(default_factory=list)

    def __str__(self):
        return json.dumps(self.to_dict(), ensure_ascii=False, indent=4)

    def to_dict(self):
        return {
            'name': self.name,
            'weight': self.weight,
            'datasets': [
                asdict(dataset) if isinstance(dataset, DatasetInfo) else dataset.to_dict()
                for dataset in self.datasets
            ],
        }

    @classmethod
    def from_dict(cls, data):
        instance = cls(name=data.get('name', ''), weight=data.get('weight', 1))
        for dataset in data.get('datasets', []):
            # A nested 'datasets' key marks a sub-collection; otherwise it is a leaf dataset.
            if 'datasets' in dataset:
                instance.datasets.append(CollectionSchema.from_dict(dataset))
            else:
                instance.datasets.append(DatasetInfo(**dataset))
        return instance

    def dump_json(self, file_path):
        d = self.to_dict()
        with open(file_path, 'w') as f:
            json.dump(d, f, ensure_ascii=False, indent=4)

    @classmethod
    def from_json(cls, file_path):
        with open(file_path, 'r') as f:
            data = json.load(f)
        return cls.from_dict(data)

    def flatten(self) -> List[DatasetInfo]:
        """Return leaf datasets with normalized weights and filled-in hierarchy paths.

        Works on a deep copy, so the schema itself is left unchanged.
        """
        collection = copy.deepcopy(self)
        flatten_name(collection)
        flatten_weight(collection)
        return flatten_datasets(collection)


if __name__ == '__main__':
    schema = CollectionSchema(
        name='reasoning',
        datasets=[
            CollectionSchema(name='english', datasets=[
                DatasetInfo(name='arc', weight=1, tags=['en']),
            ]),
            CollectionSchema(
                name='chinese',
                datasets=[DatasetInfo(name='ceval', weight=1, tags=['zh'], args={'subset_list': ['logic']})]),
        ])
    print(schema)
    print(schema.flatten())
    os.makedirs('outputs', exist_ok=True)  # make sure the output directory exists
    schema.dump_json('outputs/schema.json')

    schema = CollectionSchema.from_json('outputs/schema.json')
    print(schema)
    # Print the flattened results
    for dataset in schema.flatten():
        print(f'Dataset: {dataset.name}')
        print(f"Hierarchy: {' -> '.join(dataset.hierarchy)}")