Skip to content

Commit 8ffb37e

Browse files
jeanachoijeanachoi
andauthored
Move downloads to beginning of quantize script (#43)
* Move downloads to beginning of quantize script * address comments * simplify download --------- Co-authored-by: jeanachoi <jeanac@nvidia.com>
1 parent f437c22 commit 8ffb37e

1 file changed

Lines changed: 48 additions & 24 deletions

File tree

scripts/quantize.py

Lines changed: 48 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -49,27 +49,31 @@
4949
"""
5050

5151
import os
52+
import shlex
53+
import shutil
54+
import subprocess
5255
from typing import Annotated, Literal
5356

57+
_MINIMUM_HF_CLI_VERSION = "1.3.5"
5458

55-
def init():
56-
import logging
57-
import warnings
5859

60+
def init():
5961
os.environ["TOKENIZERS_PARALLELISM"] = "false"
6062
os.environ["TORCH_LOGS"] = "-dynamo"
61-
os.environ["LOGURU_LEVEL"] = "ERROR"
63+
64+
import logging
65+
import warnings
6266

6367
warnings.filterwarnings("ignore")
6468
logging.basicConfig(level=logging.ERROR)
6569

6670

6771
init()
72+
print("Loading dependencies...")
6873

6974

7075
import base64
7176
import json
72-
import shutil
7377
from io import BytesIO
7478
from pathlib import Path
7579

@@ -78,7 +82,6 @@ def init():
7882
import torch
7983
import tyro
8084
from datasets import load_dataset
81-
from huggingface_hub import snapshot_download
8285
from llmcompressor import oneshot
8386
from llmcompressor.modeling.moe_context import moe_calibration_context
8487
from llmcompressor.modifiers.quantization import QuantizationModifier
@@ -113,6 +116,25 @@ class Args(pydantic.BaseModel):
113116
"""Seed to use for random number generator."""
114117

115118

119+
def _hf_download(cmd_args: list[str]) -> str:
120+
"""Run Hugging Face CLI download command and return the local path.
121+
122+
Uses a newer Hugging Face CLI version to download checkpoint. The dependency
123+
version is very old and not robust.
124+
"""
125+
cmd = [
126+
"uvx",
127+
f"hf>={_MINIMUM_HF_CLI_VERSION}",
128+
"download",
129+
*cmd_args,
130+
]
131+
print(f"{shlex.join(cmd)}")
132+
subprocess.check_call(cmd, text=True)
133+
return subprocess.check_output(
134+
[*cmd, "--quiet"], text=True, env=dict(os.environ) | {"HF_HUB_OFFLINE": "1"}
135+
).strip()
136+
137+
116138
def preprocess_and_tokenize(
117139
example: dict, processor: AutoProcessor, max_sequence_length: int
118140
) -> dict:
@@ -245,6 +267,14 @@ def remove_keys(d, keys_to_remove):
245267

246268

247269
def quantize(args: Args):
270+
print("Pre-downloading dataset: lmms-lab/flickr30k")
271+
_hf_download(["lmms-lab/flickr30k", "--repo-type", "dataset"])
272+
if os.path.exists(args.model):
273+
model_path = Path(args.model)
274+
else:
275+
print(f"Pre-downloading model: {args.model}")
276+
model_path = Path(_hf_download([args.model]))
277+
248278
args.output_dir.mkdir(parents=True, exist_ok=True)
249279

250280
model = Qwen3VLForConditionalGeneration.from_pretrained(
@@ -293,28 +323,22 @@ def quantize(args: Args):
293323
config_path = output_dir / "config.json"
294324
print(f"Postprocessing config file {config_path}...")
295325
postprocess_config(config_path)
296-
if not (model_path := Path(args.model)).exists():
297-
# path for remote model / HF ID
298-
snapshot_download(
299-
repo_id=args.model,
300-
ignore_patterns=["config.json", "*.safetensors*"],
301-
local_dir=output_dir,
302-
)
303-
else:
304-
# path for local model directory
305-
files_to_copy = [
306-
f
307-
for f in model_path.glob("*")
308-
if f.name != "config.json"
309-
and "safetensors" not in f.name
310-
and not f.is_dir()
311-
]
312-
for file in files_to_copy:
313-
shutil.copy(file, output_dir / file.name)
326+
shutil.copytree(
327+
model_path,
328+
output_dir,
329+
ignore=lambda dir, files: [
330+
f for f in files if f == "config.json" or "safetensors" in f
331+
],
332+
dirs_exist_ok=True,
333+
)
314334
print(f"Quantization complete! Model saved to: {output_dir}")
315335

316336

317337
def main():
338+
from loguru import logger as loguru_logger
339+
340+
loguru_logger.remove()
341+
318342
args = tyro.cli(Args, description=__doc__)
319343
quantize(args)
320344

0 commit comments

Comments
 (0)