import argparse
import csv
import time
import warnings
from datetime import timedelta

import torch
import torch.distributed as dist
from colossalai.cluster import DistCoordinator, ProcessGroupMesh
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.utils import get_current_device, set_seed
from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from llava.conversation import conv_templates
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm

from ..datasets.utils import IMG_EXTENSIONS, VID_EXTENSIONS
from .acceleration.llava.policies import LlavaLlamaForCausalLMPolicy, LlavaMistralForCausalLMPolicy
from .utils import PROMPTS, Timer, VideoTextDataset, collate_fn

disable_torch_init()
import transformers

transformers.logging.set_verbosity_error()

class NoPaddingDistributedSampler(DistributedSampler):
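    """A ``DistributedSampler`` that pads nothing and drops nothing.

    The stock sampler either pads the dataset with duplicates or drops the tail
    so that every rank receives the same number of samples. Here the remainder
    of ``len(dataset) % num_replicas`` is assigned to the first ranks instead,
    so each sample is captioned exactly once. For example, 10 samples over 4
    ranks yield 3, 3, 2, 2 samples on ranks 0-3. Note that ``super().__init__``
    is called with ``shuffle=False``, so the ``shuffle`` argument is ignored.
    """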
    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0, drop_last=False):
        super().__init__(
            dataset=dataset, num_replicas=num_replicas, rank=rank, seed=seed, shuffle=False, drop_last=False
        )
        remainder = len(self.dataset) % self.num_replicas
        if remainder > 0 and (self.rank + 1) - remainder <= 0:
            # if the dataset is not divisible by num_replicas,
            # the remaining items are allocated to the first ranks
            self.num_samples = len(self.dataset) // self.num_replicas + 1
        else:
            self.num_samples = len(self.dataset) // self.num_replicas
        self.total_size = len(dataset)

    def __iter__(self):
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]
        else:
            indices = list(range(len(self.dataset)))  # type: ignore[arg-type]

        # total_size equals len(dataset), so no tail is dropped and nothing is padded
        indices = indices[: self.total_size]

        # subsample: rank r takes indices r, r + num_replicas, r + 2 * num_replicas, ...
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples
        return iter(indices)


@torch.inference_mode()
def main(args):
    # ======================================================
    # 1. init environment
    # ======================================================
    # we set a very large timeout to avoid some processes exiting early
    dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    set_seed(1024)
    coordinator = DistCoordinator()

    assert (
        args.dp_size * args.tp_size == coordinator.world_size
    ), f"DP size {args.dp_size} * TP size {args.tp_size} must equal world size {coordinator.world_size}"

    mesh = ProcessGroupMesh(args.dp_size, args.tp_size)
    dp_group = mesh.get_group_along_axis(0)
    tp_group = mesh.get_group_along_axis(1)
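    # axis 0 of the mesh is data parallelism and axis 1 is tensor parallelism:
    # ranks in the same tp_group jointly shard one model replica, while
    # different dp ranks process disjoint slices of the dataset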

    # ======================================================
    # 2. load model
    # ======================================================
    model_path = args.model_path
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        tokenizer, model, image_processor, context_len = load_pretrained_model(
            model_path=model_path,
            model_base=None,
            model_name=get_model_name_from_path(model_path),
            device=get_current_device(),
            torch_dtype=torch.float16,
            attn_implementation="flash_attention_2" if args.flash_attention else "eager",
        )
    dist.barrier()

    # ======================================================
    # 3. Apply system optimization
    # ======================================================
    tp_size = dist.get_world_size(tp_group)
    shard_config = ShardConfig(
        tensor_parallel_process_group=tp_group if tp_size > 1 else None,
        enable_tensor_parallelism=tp_size > 1,
    )
    shard_former = ShardFormer(shard_config=shard_config)
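    # when tensor parallelism is enabled, shard_former.optimize rewrites the
    # model according to the policy below, splitting its weights across the
    # ranks of tp_group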

    # check the model type and apply the matching shardformer policy
    model_name = model.__class__.__name__
    if dist.get_rank() == 0:
        print(f"Model class: {model_name}")
    if model_name == "LlavaLlamaForCausalLM":
        model = shard_former.optimize(model, policy=LlavaLlamaForCausalLMPolicy())[0].cuda()
    elif model_name == "LlavaMistralForCausalLM":
        model = shard_former.optimize(model, policy=LlavaMistralForCausalLMPolicy())[0].cuda()
    else:
        print(f"The shardformer policy for {model_name} is not implemented, skipping optimization")
    torch.cuda.empty_cache()

    # ======================================================
    # 4. Prepare dataloader
    # ======================================================
    # prepare prompt
    query = PROMPTS[args.prompt]["text"]
    if dist.get_rank() == 0:
        print(f"Prompt: {query}")

    if "text" in args.prompt:
        # the prompt contains a placeholder, so input_ids are built per sample
        def get_text_input_ids(text):
            conv = conv_templates["chatml_direct"].copy()
            query_text = query.format(text)
            conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\n" + query_text)
            prompt = conv.get_prompt()
            # replicate the <image> token once per sampled frame
            t = prompt.split("<image>")
            prompt = t[0] + "<image>" * args.num_frames + t[1]
            input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
            input_ids = input_ids.unsqueeze(0)
            return input_ids

    else:
        # the prompt is static, so input_ids are built once and reused
        conv = conv_templates["chatml_direct"].copy()
        conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\n" + query)
        prompt = conv.get_prompt()
        # replicate the <image> token once per sampled frame
        t = prompt.split("<image>")
        prompt = t[0] + "<image>" * args.num_frames + t[1]
        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        input_ids = input_ids.unsqueeze(0)

        def get_text_input_ids(*args):
            return input_ids

    # build dataset
    def transform(imgs):
        imgs = process_images(imgs, image_processor, model.config)
        imgs = imgs.to(dtype=torch.float16)
        return imgs

    dataset = VideoTextDataset(
        args.input,
        transform=transform,
        num_frames=args.num_frames,
        get_text_input_ids=get_text_input_ids,
        resize=args.resize,
    )

    # make sure that the prompt type matches the data type
    data_extension = "." + dataset.data["path"].iloc[0].split(".")[-1]
    prompt_type = PROMPTS[args.prompt]["type"]
    if prompt_type == "image":
        assert (
            data_extension.lower() in IMG_EXTENSIONS
        ), f"The prompt expects image data, but the input has extension {data_extension}."
    elif prompt_type == "video":
        assert (
            data_extension.lower() in VID_EXTENSIONS
        ), f"The prompt expects video data, but the input has extension {data_extension}."
    else:
        raise ValueError(f"Found invalid prompt type {prompt_type}")

    total_num_videos = len(dataset)

    # build sampler
    dp_rank = dist.get_rank(dp_group)
    dp_size = dist.get_world_size(dp_group)
    sampler = NoPaddingDistributedSampler(dataset, rank=dp_rank, num_replicas=dp_size)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.bs,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
        prefetch_factor=args.prefetch_factor,
        sampler=sampler,
        collate_fn=collate_fn,
    )
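
    # the input CSV must contain a "path" column; all other columns are carried
    # over to the output unchanged. An illustrative input file:
    #   path,id
    #   /data/videos/0001.mp4,0001
    #   /data/videos/0002.mp4,0002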
    with open(args.input, "r") as f:
        reader = csv.DictReader(f)
        # index rows by path so generated captions can be merged back in O(1)
        original_data = {row["path"]: row for row in reader}
        headers = reader.fieldnames

    # prepare output file path
    output_file = args.input.replace(".csv", "_caption.csv")

    # ======================================================
    # 5. generate captions
    # ======================================================
    if dist.get_rank() == 0:
        pbar = tqdm(dataloader, position=dp_rank, desc=f"Data Parallel Rank {dist.get_rank(dp_group)}")
    else:
        pbar = dataloader

    if args.profile:
        encode_time = []
        generate_time = []
        output_length = []
        total_time = []

    results = []
    for i, batch in enumerate(pbar):
        # measure time
        if args.profile:
            torch.cuda.synchronize()
            start_time = time.time()

        video_files, frames, video_lengths, img_size_list, texts = batch

        # encode the batch of inputs
        with Timer() as encode_timer:
            samples = []
            for imgs, imgs_size, input_ids in zip(frames, img_size_list, texts):
                imgs = imgs.cuda()
                input_ids = input_ids.cuda()
                _, _, _, _, inputs_embeds, _ = model.prepare_inputs_labels_for_multimodal(
                    input_ids, None, None, None, None, images=imgs, image_sizes=imgs_size
                )
                samples.append(inputs_embeds)

            # left-pad each sequence with zero embeddings up to the batch max
            # length; padded positions get 0 in the attention mask
            max_len = max(sample.shape[1] for sample in samples)
            attention_mask = torch.tensor(
                [[0] * (max_len - samples[i].shape[1]) + [1] * samples[i].shape[1] for i in range(len(samples))]
            ).to(model.device)
            inputs_embeds = [
                torch.cat(
                    [
                        torch.zeros(
                            (1, max_len - samples[i].shape[1], samples[i].shape[-1]),
                            device=model.device,
                            dtype=torch.float16,
                        ),
                        samples[i],
                    ],
                    dim=1,
                )
                for i in range(len(samples))
            ]
            inputs_embeds = torch.cat(inputs_embeds, dim=0)

        with Timer() as generate_timer:
            # the multimodal embeddings are already prepared, so bypass LLaVA's
            # generate override (which expects raw images) and call the parent
            # class's generate directly
            output_ids = super(type(model), model).generate(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                do_sample=False,
                max_new_tokens=args.max_tokens,
                use_cache=True,
            )
        # skip warmup batches before recording profiling data
        if args.profile and i >= args.profile_warmup:
            output_length.append(output_ids.size(0) * output_ids.size(1))

        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        outputs = [output.replace("\n", " ").strip() for output in outputs]

        # skip warmup batches before recording profiling data
        if args.profile and i >= args.profile_warmup:
            # measure time
            torch.cuda.synchronize()
            time_taken = time.time() - start_time

            total_time.append(time_taken)
            encode_time.append(encode_timer.time_taken)
            generate_time.append(generate_timer.time_taken)

        # merge each generated caption back into its original row
        for video_file, output_text, video_length in zip(video_files, outputs, video_lengths):
            original_row = original_data[video_file]
            original_row["text"] = output_text
            original_row["num_frames"] = video_length
            results.append(original_row)

    # display profiling info
    if args.profile:
        print(f"output lengths per batch: {output_length}")
        num_samples_after_warmup = total_num_videos - args.bs * args.profile_warmup * dp_size
        print(f"throughput (samples/s): {num_samples_after_warmup / sum(total_time)}")
        print(f"average encode time per sample: {sum(encode_time) / num_samples_after_warmup}")
        print(f"average generate time per sample: {sum(generate_time) / num_samples_after_warmup}")
        print(f"average number of tokens per sample: {sum(output_length) / num_samples_after_warmup}")
        print(f"Max GPU allocated / GB: {torch.cuda.max_memory_allocated() / 1024**3}")
        print(f"Max GPU reserved / GB: {torch.cuda.max_memory_reserved() / 1024**3}")

    # every rank in a TP group decodes the same samples, so keep only one copy
    # per DP group to avoid duplicated rows in the output CSV
    if dist.get_rank(tp_group) != 0:
        results = []

    if dist.get_rank() == 0:
        all_results = [None] * dist.get_world_size()
    else:
        all_results = None
    dist.gather_object(results, all_results, dst=0)

    if dist.get_rank() == 0:
        all_results = [item for sublist in all_results if sublist is not None for item in sublist]

        with open(output_file, "w", newline="") as f:
            # extend the header with the generated columns, avoiding duplicates
            extra_columns = [col for col in ("text", "num_frames") if col not in headers]
            writer = csv.DictWriter(f, fieldnames=headers + extra_columns)
            writer.writeheader()
            writer.writerows(all_results)

        print(f"Results saved to {output_file}")

    dist.barrier()
    dist.destroy_process_group()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("input", type=str, help="Path to the input CSV file")
    parser.add_argument("--model-path", type=str, default="liuhaotian/llava-v1.6-34b")
    parser.add_argument("--prompt", type=str, default="video-f1-detail-3ex")
    parser.add_argument("--resize", type=int, default=336)
    parser.add_argument("--num-frames", type=int, default=1)
    parser.add_argument("--max-tokens", type=int, default=300)
    parser.add_argument("--bs", type=int, default=16)
    parser.add_argument("--tp-size", type=int, default=2)
    parser.add_argument("--dp-size", type=int, default=4)
    parser.add_argument("--num-workers", type=int, default=8)
    parser.add_argument("--prefetch-factor", type=int, default=8, help="Prefetch factor")
    parser.add_argument(
        "--flash-attention",
        action="store_true",
        help="Whether to use flash attention. Turn this on for LLaMA models and off for Mistral models.",
    )
    # debug related
    parser.add_argument("--profile", action="store_true")
    parser.add_argument("--profile-warmup", type=int, default=1)

    args = parser.parse_args()
    main(args)
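
# Example launch (illustrative; the module path depends on where this file lives
# in your package, and dp_size * tp_size must match the number of GPUs):
#   torchrun --standalone --nproc_per_node 8 -m tools.caption.caption_llava \
#       samples.csv --dp-size 4 --tp-size 2 --flash-attention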