4949"""
5050
5151import os
52+ import shlex
53+ import shutil
54+ import subprocess
5255from typing import Annotated , Literal
5356
57+ _MINIMUM_HF_CLI_VERSION = "1.3.5"
5458
def init():
    """Silence noisy library output before the heavy ML imports execute."""
    # Environment knobs must be set before tokenizers/torch are imported.
    for key, value in (
        ("TOKENIZERS_PARALLELISM", "false"),
        ("TORCH_LOGS", "-dynamo"),
    ):
        os.environ[key] = value

    import logging
    import warnings

    warnings.filterwarnings("ignore")
    logging.basicConfig(level=logging.ERROR)
6569
6670
# Quiet logging/warnings first so the import storm below stays silent.
init()
print("Loading dependencies...")
6873
6974
7075import base64
7176import json
72- import shutil
7377from io import BytesIO
7478from pathlib import Path
7579
@@ -78,7 +82,6 @@ def init():
7882import torch
7983import tyro
8084from datasets import load_dataset
81- from huggingface_hub import snapshot_download
8285from llmcompressor import oneshot
8386from llmcompressor .modeling .moe_context import moe_calibration_context
8487from llmcompressor .modifiers .quantization import QuantizationModifier
@@ -113,6 +116,25 @@ class Args(pydantic.BaseModel):
113116 """Seed to use for random number generator."""
114117
115118
def _hf_download(cmd_args: list[str]) -> str:
    """Download a repo via the Hugging Face CLI and return its local path.

    Uses a newer Hugging Face CLI version to download checkpoint. The dependency
    version is very old and not robust.
    """
    cmd = ["uvx", f"hf>={_MINIMUM_HF_CLI_VERSION}", "download", *cmd_args]
    print(f"{shlex.join(cmd)}")

    # First pass: run the download normally so progress is streamed to the user.
    subprocess.check_call(cmd, text=True)

    # Second pass: re-run quietly in offline mode; stdout is then just the
    # cached snapshot path, which we return stripped of whitespace.
    offline_env = dict(os.environ) | {"HF_HUB_OFFLINE": "1"}
    quiet_cmd = [*cmd, "--quiet"]
    return subprocess.check_output(quiet_cmd, text=True, env=offline_env).strip()
137+
116138def preprocess_and_tokenize (
117139 example : dict , processor : AutoProcessor , max_sequence_length : int
118140) -> dict :
@@ -245,6 +267,14 @@ def remove_keys(d, keys_to_remove):
245267
246268
247269def quantize (args : Args ):
270+ print ("Pre-downloading dataset: lmms-lab/flickr30k" )
271+ _hf_download (["lmms-lab/flickr30k" , "--repo-type" , "dataset" ])
272+ if os .path .exists (args .model ):
273+ model_path = Path (args .model )
274+ else :
275+ print (f"Pre-downloading model: { args .model } " )
276+ model_path = Path (_hf_download ([args .model ]))
277+
248278 args .output_dir .mkdir (parents = True , exist_ok = True )
249279
250280 model = Qwen3VLForConditionalGeneration .from_pretrained (
@@ -293,28 +323,22 @@ def quantize(args: Args):
293323 config_path = output_dir / "config.json"
294324 print (f"Postprocessing config file { config_path } ..." )
295325 postprocess_config (config_path )
296- if not (model_path := Path (args .model )).exists ():
297- # path for remote model / HF ID
298- snapshot_download (
299- repo_id = args .model ,
300- ignore_patterns = ["config.json" , "*.safetensors*" ],
301- local_dir = output_dir ,
302- )
303- else :
304- # path for local model directory
305- files_to_copy = [
306- f
307- for f in model_path .glob ("*" )
308- if f .name != "config.json"
309- and "safetensors" not in f .name
310- and not f .is_dir ()
311- ]
312- for file in files_to_copy :
313- shutil .copy (file , output_dir / file .name )
326+ shutil .copytree (
327+ model_path ,
328+ output_dir ,
329+ ignore = lambda dir , files : [
330+ f for f in files if f == "config.json" or "safetensors" in f
331+ ],
332+ dirs_exist_ok = True ,
333+ )
314334 print (f"Quantization complete! Model saved to: { output_dir } " )
315335
316336
def main():
    """CLI entry point: disable loguru's default sink, parse args, quantize."""
    from loguru import logger as loguru_logger

    # Remove loguru's default stderr handler so dependency log noise is dropped.
    loguru_logger.remove()

    quantize(tyro.cli(Args, description=__doc__))
0 commit comments