sglang_v0.5.2/vision_0.23.0/scripts/download_model_urls.py

42 lines
1.3 KiB
Python

import asyncio
import sys
from pathlib import Path
from time import perf_counter
from urllib.parse import urlsplit
import aiofiles
import aiohttp
from torchvision import models
from tqdm.asyncio import tqdm
async def main(download_root):
    """Download every weight file known to ``torchvision.models`` into *download_root*.

    The directory is created (including parents) if it does not exist. The
    URLs of all weights of all listed models are collected into a set, so a
    URL shared by several entries is fetched only once. All downloads are
    scheduled together and awaited behind a tqdm progress bar.
    """
    download_root.mkdir(parents=True, exist_ok=True)

    # A set de-duplicates URLs that appear for more than one weight entry.
    urls = set()
    for model_name in models.list_models():
        for weight in models.get_model_weights(model_name):
            urls.add(weight.url)

    # total=None disables aiohttp's overall per-request time limit.
    no_total_timeout = aiohttp.ClientTimeout(total=None)
    async with aiohttp.ClientSession(timeout=no_total_timeout) as session:
        pending = [download(download_root, session, url) for url in urls]
        await tqdm.gather(*pending)
async def download(download_root, session, url):
    """Stream the file at *url* into *download_root*.

    The target file name is the last path component of *url*. The request is
    sent with the query parameter ``source=ci``.

    Args:
        download_root: Directory (a ``pathlib.Path``) the file is written to.
        session: The ``aiohttp`` client session used to issue the GET request.
        url: URL of the file to download.

    Raises:
        aiohttp.ClientResponseError: If the server answers with a status
            code of 400 or above.
    """
    file_name = Path(urlsplit(url).path).name
    # The context manager releases the connection even when the status check
    # or a write fails part-way through.
    async with session.get(url, params=dict(source="ci")) as response:
        # An explicit check instead of ``assert``: assertions are stripped
        # under ``python -O``, which would let an error page be saved as a
        # weight file.
        response.raise_for_status()
        async with aiofiles.open(download_root / file_name, "wb") as f:
            # Write chunks as they arrive rather than buffering the whole body.
            async for data in response.content.iter_any():
                await f.write(data)
if __name__ == "__main__":
    # Destination directory: the first command-line argument if given,
    # otherwise ~/.cache/torch/hub/checkpoints. "~" is expanded and the path
    # is made absolute.
    download_root = (
        (Path(sys.argv[1]) if len(sys.argv) > 1 else Path("~/.cache/torch/hub/checkpoints")).expanduser().resolve()
    )
    print(f"Downloading model weights to {download_root}")
    start = perf_counter()
    # asyncio.run creates and closes its own event loop; calling
    # asyncio.get_event_loop() with no running loop is deprecated and fails on
    # recent Python versions.
    asyncio.run(main(download_root))
    stop = perf_counter()
    # Round to whole seconds before splitting so that e.g. 119.6 s is shown as
    # " 2m  0s" instead of " 1m 60s".
    minutes, seconds = divmod(round(stop - start), 60)
    print(f"Download took {minutes:2.0f}m {seconds:2.0f}s")