2 changes: 1 addition & 1 deletion gptqmodel/models/definitions/gpt_oss.py
@@ -123,7 +123,7 @@ def forward(self, hidden_states):
         router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)  # (seq_len, top_k)
         router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
         router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
-        return router_scores, router_indices
+        return router_logits, router_scores, router_indices
 
 class GPTOSSGPTQ(BaseQModel):
     dynamic_expert_index = "num_local_experts"
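For reference, a self-contained sketch of the top-k routing computation in the hunk above; the sizes below are illustrative, not taken from the PR:

    import torch

    seq_len, num_experts, top_k = 4, 8, 2  # illustrative sizes
    router_logits = torch.randn(seq_len, num_experts)

    # pick the top_k experts per token, then softmax only over those scores
    router_top_value, router_indices = torch.topk(router_logits, top_k, dim=-1)
    router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)

    # scatter the normalized weights back into a dense (seq_len, num_experts) score matrix
    router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)

    assert torch.allclose(router_scores.sum(dim=1), torch.ones(seq_len))  # each row sums to 1

The change above keeps this computation intact and additionally exposes the raw router_logits to callers alongside the scores and indices.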
14 changes: 11 additions & 3 deletions gptqmodel/nn_modules/converter.py
@@ -9,18 +9,26 @@ def convert_gpt_oss_expert_converter(module, config):
     import transformers.models.gpt_oss.modeling_gpt_oss as gpt_oss_modeling
     from transformers.integrations.hub_kernels import use_kernel_forward_from_hub
 
-    from ..models.definitions.gpt_oss import GptOssExpertsNew
+    from ..models.definitions.gpt_oss import GptOssExpertsNew, GptOssTopKRouterNew
 
     @use_kernel_forward_from_hub("MegaBlocksMoeMLP")
     class GptOssMLPNew(nn.Module):
         def __init__(self, config, ori_mlp=None):
             super().__init__()
-            self.router = ori_mlp.router
+            self.router = GptOssTopKRouterNew(config, ori_mlp.router)
             experts_new = GptOssExpertsNew(config, ori_mlp.experts)
             self.experts = experts_new
 
         def forward(self, hidden_states):
-            router_scores, router_indices = self.router(hidden_states)  # (num_experts, seq_len)
+            router_output = self.router(hidden_states)
+            if isinstance(router_output, tuple) and len(router_output) == 3:
+                _, router_scores, router_indices = router_output
+            elif isinstance(router_output, tuple) and len(router_output) == 2:
+                router_scores, router_indices = router_output
+            else:
+                raise ValueError(
+                    f"Unexpected GPT-OSS router output during conversion: {type(router_output).__name__}"
+                )
             routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores)
             return routed_out, router_scores

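As a quick standalone check of the tuple-length dispatch added above, a minimal sketch; the unpack helper is hypothetical, written only to mirror the conversion-time handling:

    def unpack_router_output(router_output):
        # patched router: (logits, scores, indices); legacy router: (scores, indices)
        if isinstance(router_output, tuple) and len(router_output) == 3:
            _, router_scores, router_indices = router_output
            return router_scores, router_indices
        if isinstance(router_output, tuple) and len(router_output) == 2:
            return router_output
        raise ValueError(
            f"Unexpected GPT-OSS router output during conversion: {type(router_output).__name__}"
        )

    # both contracts reduce to (scores, indices)
    assert unpack_router_output(("logits", "scores", "indices")) == ("scores", "indices")
    assert unpack_router_output(("scores", "indices")) == ("scores", "indices")

Accepting both tuple shapes keeps the converter working whether or not the router has been swapped for the patched GptOssTopKRouterNew.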
1 change: 1 addition & 0 deletions tests/models/test_gpt_oss.py
@@ -10,6 +10,7 @@
 
 class TestGPTOSS(ModelTest):
     NATIVE_MODEL_ID = "/monster/data/model/gpt-oss-20b-BF16/"
+    USE_FLASH_ATTN = False
     EVAL_TASKS = {
         EVAL.LM_EVAL.ARC_CHALLENGE: {
             "chat_template": False,
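For context, a hedged sketch of what a test-level flag like USE_FLASH_ATTN typically gates when a harness loads a model with transformers; the mapping below is an assumption for illustration, not gptqmodel's actual loading code:

    from transformers import AutoModelForCausalLM

    def load_for_test(model_id: str, use_flash_attn: bool):
        # assumed mapping: the flag selects transformers' attn_implementation,
        # with False falling back to the default eager attention path
        attn_impl = "flash_attention_2" if use_flash_attn else "eager"
        return AutoModelForCausalLM.from_pretrained(model_id, attn_implementation=attn_impl)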
118 changes: 118 additions & 0 deletions tests/test_model_test_helpers.py
@@ -0,0 +1,118 @@
import torch

from models.model_test import ModelTest


class FakeBatchEncoding(dict):
def __init__(self, input_ids):
super().__init__(input_ids=input_ids)
self.input_ids = input_ids

def to(self, _device):
return self


class FakeTokenizer:
pad_token_id = None
eos_token_id = 7

def __init__(self):
self.decode_calls = []
self.batch_decode_calls = []

def __call__(self, prompt, return_tensors="pt"):
assert return_tensors == "pt"
assert prompt == "hello"
return FakeBatchEncoding(torch.tensor([[101, 102]]))

def decode(self, tokens, skip_special_tokens=True):
self.decode_calls.append(
{
"tokens": tokens.tolist(),
"skip_special_tokens": skip_special_tokens,
}
)
return f"decoded:{tokens.tolist()}"

def batch_decode(self, sequences, skip_special_tokens=True, clean_up_tokenization_spaces=False):
self.batch_decode_calls.append(
{
"sequences": [seq.tolist() for seq in sequences],
"skip_special_tokens": skip_special_tokens,
"clean_up_tokenization_spaces": clean_up_tokenization_spaces,
}
)
return [f"batch:{[seq.tolist() for seq in sequences]}"]


class FakeProcessor(FakeTokenizer):
pass


class FakeModel:
def __init__(self, generated):
self.device = "cpu"
self.generated = generated
self.calls = []

def generate(self, **kwargs):
self.calls.append(kwargs)
return self.generated


def test_generate_stable_with_limit_for_prompt_uses_deterministic_kwargs():
tokenizer = FakeTokenizer()
model = FakeModel(torch.tensor([[101, 102, 103, 104]]))

output = ModelTest.generate_stable_with_limit(
model,
tokenizer,
"hello",
min_new_tokens=2,
max_new_tokens=4,
skip_special_tokens=False,
)

assert output == "decoded:[101, 102, 103, 104]"
assert len(model.calls) == 1
assert model.calls[0]["do_sample"] is False
assert model.calls[0]["num_beams"] == 1
assert model.calls[0]["min_new_tokens"] == 2
assert model.calls[0]["max_new_tokens"] == 4
assert model.calls[0]["pad_token_id"] == tokenizer.eos_token_id
assert model.calls[0]["eos_token_id"] == tokenizer.eos_token_id
assert tokenizer.decode_calls == [
{
"tokens": [101, 102, 103, 104],
"skip_special_tokens": False,
}
]


def test_generate_stable_with_limit_for_prepared_inputs_batch_decodes_suffix():
processor = FakeProcessor()
prepared_inputs = FakeBatchEncoding(torch.tensor([[10, 11]]))
model = FakeModel(torch.tensor([[10, 11, 21, 22]]))

output = ModelTest.generate_stable_with_limit(
model,
processor,
inputs=prepared_inputs,
prompt=None,
batch_decode=True,
max_new_tokens=2,
clean_up_tokenization_spaces=False,
)

assert output == "batch:[[21, 22]]"
assert len(model.calls) == 1
assert model.calls[0]["input_ids"].tolist() == [[10, 11]]
assert model.calls[0]["do_sample"] is False
assert model.calls[0]["num_beams"] == 1
assert processor.batch_decode_calls == [
{
"sequences": [[21, 22]],
"skip_special_tokens": True,
"clean_up_tokenization_spaces": False,
}
]
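These helper tests exercise ModelTest.generate_stable_with_limit with fakes instead of a real tokenizer and model, so they need no GPU or checkpoint; assuming the repo's usual pytest layout, they can be run directly:

    pytest tests/test_model_test_helpers.py -q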