Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 34 additions & 11 deletions gptqmodel/models/definitions/base_qwen3_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,30 +54,53 @@ class BaseQwen3VLGPTQ(BaseQModel):

require_load_processor = True

def _language_model_root(self):
if hasattr(self.model, "language_model"):
return self.model
if hasattr(self.model, "model") and hasattr(self.model.model, "language_model"):
return self.model.model
raise AttributeError(
f"{type(self.model).__name__} does not expose a language_model root"
)

def pre_quantize_generate_hook_start(self):
    """Move calibration-time helper modules onto the quantization device.

    Before calibration generation runs, the token embeddings, rotary
    embeddings, and the vision tower must live on the quantize device.
    ``_language_model_root`` resolves the owner so both the
    ``model.language_model`` and ``model.model.language_model`` layouts
    are supported.
    """
    # NOTE: the diff artifact left stale `self.model.language_model...`
    # statements interleaved here; only the root-resolved versions belong.
    model_root = self._language_model_root()
    model_root.language_model.embed_tokens = move_to(
        model_root.language_model.embed_tokens,
        device=self.quantize_config.device,
    )
    model_root.language_model.rotary_emb = move_to(
        model_root.language_model.rotary_emb,
        device=self.quantize_config.device,
    )
    model_root.visual = move_to(model_root.visual, device=self.quantize_config.device)

def pre_quantize_generate_hook_end(self):
    """Release calibration-time helper modules after generation.

    When disk offload is enabled, the embeddings, rotary embeddings, and
    vision tower are offloaded to ``offload_to_disk_path``; otherwise they
    are moved back to CPU to free accelerator memory.
    """
    model_root = self._language_model_root()
    if self.quantize_config.offload_to_disk:
        offload_to_disk(
            model=model_root.language_model,
            module=model_root.language_model.embed_tokens,
            disk_path=self.quantize_config.offload_to_disk_path,
        )
        offload_to_disk(
            model=model_root.language_model,
            module=model_root.language_model.rotary_emb,
            disk_path=self.quantize_config.offload_to_disk_path,
        )
        # NOTE(review): the visual offload anchors on self.model while the
        # module is resolved via model_root — confirm this asymmetry is
        # intentional for nested (model.model) layouts.
        offload_to_disk(
            model=self.model,
            module=model_root.visual,
            disk_path=self.quantize_config.offload_to_disk_path,
        )
        return

    model_root.language_model.embed_tokens = move_to(
        model_root.language_model.embed_tokens,
        device=CPU,
    )
    model_root.language_model.rotary_emb = move_to(
        model_root.language_model.rotary_emb,
        device=CPU,
    )
    model_root.visual = move_to(model_root.visual, device=CPU)

@staticmethod
def process_vision_info(
Expand Down
29 changes: 24 additions & 5 deletions gptqmodel/utils/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,21 @@ def _module_all_meta(mod: nn.Module) -> bool:
def _is_leaf(mod: nn.Module) -> bool:
return next(mod.named_children(), None) is None

def _prune_stale_meta_buffers(shell_sub: nn.Module, turtle_sub: nn.Module) -> int:
    """Count (and drop) shell meta-buffers the turtle module no longer has.

    A buffer on *shell_sub* is stale when it sits on the meta device and has
    no same-named counterpart on *turtle_sub*. Stale buffers are deleted from
    ``shell_sub._buffers`` when present there.

    Returns:
        The number of stale meta buffers found (each counted even if it was
        not present in ``_buffers`` for deletion).
    """
    turtle_names = {name for name, _ in turtle_sub.named_buffers(recurse=False)}
    pruned = 0
    for name, buf in list(shell_sub.named_buffers(recurse=False)):
        if name in turtle_names or not _is_meta_tensor(buf):
            continue
        # Delete directly from the registry when the entry exists there.
        if name in getattr(shell_sub, "_buffers", {}):
            del shell_sub._buffers[name]
        pruned += 1
    return pruned

def alias_all_from_turtle_if_meta(
shell_model: nn.Module,
turtle_model: nn.Module,
Expand All @@ -640,14 +655,18 @@ def alias_all_from_turtle_if_meta(
for qname, shell_sub in list(shell_model.named_modules()):
if not qname: # skip root
continue
if not _is_leaf(shell_sub):
continue
if not _module_all_meta(shell_sub):
continue

turtle_sub = turtle_map.get(qname, None)
if turtle_sub is None:
# log.info(f"Module: Skipped {qname}: not found in turtle model")
continue

removed = _prune_stale_meta_buffers(shell_sub, turtle_sub)
if removed:
log.info(f"Module: Pruned {removed} stale meta buffer(s) from {qname}")

if not _is_leaf(shell_sub):
continue
if not _module_all_meta(shell_sub):
continue

if require_class_match and (shell_sub.__class__ is not turtle_sub.__class__):
Expand Down
13 changes: 13 additions & 0 deletions model_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""Compatibility shim: re-export the real ModelTest implementation.

Loads ``tests/models/model_test.py`` by path and mirrors its public names
here so legacy top-level ``model_test`` imports keep working.
"""
import importlib.util
import sys
from pathlib import Path


_MODEL_TEST_PATH = Path(__file__).resolve().parent / "tests" / "models" / "model_test.py"
_SPEC = importlib.util.spec_from_file_location("_model_test_impl", _MODEL_TEST_PATH)
if _SPEC is None or _SPEC.loader is None:
    raise ImportError(f"Unable to load ModelTest from {_MODEL_TEST_PATH}")

_MODULE = importlib.util.module_from_spec(_SPEC)
# Register before exec_module so the implementation can resolve itself
# (pickling, dataclasses, recursive imports) while executing — per the
# importlib documentation's "importing a source file directly" recipe.
sys.modules[_SPEC.name] = _MODULE
_SPEC.loader.exec_module(_MODULE)

# Mirror `from module import *` semantics: honor __all__ when defined,
# otherwise export every non-underscore-prefixed name.
_PUBLIC = getattr(_MODULE, "__all__", None)
if _PUBLIC is None:
    _PUBLIC = [name for name in dir(_MODULE) if not name.startswith("_")]
globals().update({name: getattr(_MODULE, name) for name in _PUBLIC})
1 change: 1 addition & 0 deletions tests/model_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from models.model_test import * # noqa: F401,F403
1 change: 1 addition & 0 deletions tests/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Package marker for pytest collection to avoid basename collisions.
1 change: 1 addition & 0 deletions tests/models/awq/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Package marker for pytest collection to avoid basename collisions.
Loading