Support Step3.5-Flash #19283
base: master

Conversation
Adding supplemental evaluation results for reference.
Performance: https://github.com/stepfun-ai/Step-3.5-Flash/blob/main/llama.cpp/docs/step3.5-flash.md
Accuracy: evaluated against a BF16 vLLM baseline.
Tested the maximum 256k context on 8 × H200 devices.
Tested the maximum 256k context on a Mac Studio.
Is this exactly the same modification that was done in the forked Step llama.cpp, or is it a new one?
@gopinath87607 The register name (step3p5) was modified in the convert_hf_to_gguf part. Everything else is exactly the same.
I pulled and compiled with this commit, then produced a BF16 with convert_hf_to_gguf, then attempted to run llama-imatrix on it, and the results were looking very suspect (llama-imatrix output on commit `2f0f12e70`). I canceled it because the partial data for the experts and the 80,000+ PPL make it seem like something has gone wrong in the conversion or inference process somewhere.
Same issue here, about 'tool_call'. Edit: however, the result is correct; it indeed helped me write the HTML game I wanted. @forforever73
Running it, the speed I'm getting is: … This is on an Epyc 9274F / 12 × 32 GB 4800 MT/s / dual NVIDIA A5000.
@AesSedai Sorry about that. For now, please use the pre-quantized GGUF model: https://huggingface.co/stepfun-ai/Step-3.5-Flash-Int4
Tool calling is still missing some support in llama.cpp at the moment. I'll submit the next PR to address this as soon as possible 💪🙂
After testing, I found that this bug occurs when multiple MCP tools are provided. With only a single MCP tool, the issue does not seem to occur.
Looking forward to it! This is the best LLM I could run locally so far, thank you for it! |
@tarruda I do share your thoughts. This model seems extremely intelligent. Running at ~16 tok/s with 2x RTX 3090 and 128 GB DDR4. Makes me want to invest in Pro 6000 Blackwells lmao!
If someone wants a version with fully working reasoning + tool calling, I've added a cherry-picked version of my autoparser branch. Already tested with OpenCode and works great so far. https://github.com/pwilkin/llama.cpp/tree/autoparser-stepfun |
Thank you @pwilkin, will use that branch for now!
This doesn't compile for me:
@drrros sorry, forgot to commit that fix, try now. |
I ran into the same issue, but taking https://github.com/pwilkin/llama.cpp/tree/autoparser and then cherry-picking this PR's commit on top worked for me. I do occasionally see "Invalid diff:" exceptions: a tool "string" parameter (which happens to consist of only digits and is incidentally also a legal integer) is shown once with and once without quotes.
@pwilkin Compiling now, thanks |
That's a good debug case, could you possibly paste it here? |
@pwilkin I tried your branch and it does fix the tool call issue — thanks! |
gguf-py/gguf/gguf_writer.py
Outdated
def add_rope_scaling_apply_mask(self, yarn_only_types: Sequence[str] | None) -> None:
    apply_mask = 0x3  # default: apply on all layers (backwards compatible)
    if isinstance(yarn_only_types, list):
        apply_mask = 0
        if "full_attention" in yarn_only_types:
            apply_mask |= 0x1
        if "sliding_attention" in yarn_only_types:
            apply_mask |= 0x2
    self.add_uint32(Keys.Rope.SCALING_APPLY_MASK.format(arch=self.arch), int(apply_mask))
This is too hacky IMO; we already have a notion of hparams.swa_layers and it should be used instead. See the MiMo2 model for an example:
ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer);
src/models/step35-iswa.cpp
Outdated
const uint32_t apply_mask = hparams.rope_scaling_apply_mask;
if ((is_swa && (apply_mask & 0x2)) || (!is_swa && (apply_mask & 0x1))) {
    rope_factors = model.get_rope_factors(cparams, il);
}
The use of the word "mask" here is quite ambiguous and can be interpreted as something like an attention mask. Either rename it to "bitmask" or, better, don't use a bit mask at all; save it as a dedicated std::array<bool, LLAMA_MAX_LAYERS>.
You really don't need a mask here: is_swa already provides the information about SWA layers, and the rope scaling flag can just be an array of bool.
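For illustration, here is a minimal standalone sketch of the suggested approach (hypothetical names and layer pattern, not actual llama.cpp code): a per-layer bool array that the graph code can index directly, instead of decoding a 2-bit apply mask.

// Standalone sketch only: MAX_LAYERS and rope_yarn_layer are stand-ins for
// LLAMA_MAX_LAYERS and a would-be hparams field; the layer pattern is invented.
#include <array>
#include <cstdio>

constexpr int MAX_LAYERS = 8;

int main() {
    // true = apply the YaRN rope factors on this layer
    std::array<bool, MAX_LAYERS> rope_yarn_layer{};

    // example pattern: YaRN only on the full-attention layers (every 4th layer here)
    for (int il = 0; il < MAX_LAYERS; ++il) {
        const bool is_swa = (il % 4 != 3);
        rope_yarn_layer[il] = !is_swa;
    }

    // the per-layer decision becomes a plain lookup instead of bit twiddling
    for (int il = 0; il < MAX_LAYERS; ++il) {
        std::printf("layer %d: %s\n", il, rope_yarn_layer[il] ? "apply rope factors" : "skip rope factors");
    }
    return 0;
}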
src/llama-hparams.h
Outdated
std::array<float, LLAMA_MAX_LAYERS> swiglu_limits;
std::array<float, LLAMA_MAX_LAYERS> swiglu_limits_shared;
I'm not a fan of calling the same thing by different names. This is just clamping; even the Python code calls it "clamp". I don't care how config.json calls it.
Suggested change:
- std::array<float, LLAMA_MAX_LAYERS> swiglu_limits;
- std::array<float, LLAMA_MAX_LAYERS> swiglu_limits_shared;
+ std::array<float, LLAMA_MAX_LAYERS> swiglu_clamp_exp;   // clamping for expert FFN
+ std::array<float, LLAMA_MAX_LAYERS> swiglu_clamp_shexp; // shared exp
src/llama-model.cpp
Outdated
| format("%s.rope.scaling.apply_mask", ml.get_arch_name().c_str()), | ||
| hparams.rope_scaling_apply_mask, | ||
| false | ||
| ); | ||
|
|
||
| hparams.has_rope_freq_base_per_layer = ml.get_key_or_arr( | ||
| format("%s.rope.freq_base_per_layer", ml.get_arch_name().c_str()), | ||
| hparams.rope_freq_base_per_layer, | ||
| hparams.n_layer, | ||
| false | ||
| ); |
Add these as proper LLM_KV_* keys, like all other models.
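Roughly, the LLM_KV_* pattern means declaring a symbolic key that maps to the "%s.…" format string in one place and using the symbol at every call site. A standalone sketch (the enum and helper below are hypothetical, not the real llama.cpp definitions):

// Hypothetical sketch of the LLM_KV_* idea: key enum + name table + formatter.
#include <cstdio>
#include <map>
#include <string>

enum llm_kv_sketch {
    KV_ROPE_SCALING_APPLY_MASK,   // hypothetical key
    KV_ROPE_FREQ_BASE_PER_LAYER,  // hypothetical key
};

static const std::map<llm_kv_sketch, const char *> KV_NAMES = {
    { KV_ROPE_SCALING_APPLY_MASK,  "%s.rope.scaling.apply_mask"  },
    { KV_ROPE_FREQ_BASE_PER_LAYER, "%s.rope.freq_base_per_layer" },
};

static std::string kv_name(llm_kv_sketch kv, const std::string & arch) {
    char buf[256];
    std::snprintf(buf, sizeof(buf), KV_NAMES.at(kv), arch.c_str());
    return buf;
}

int main() {
    // call sites would then reference the symbolic key, e.g. (pseudocode):
    //   ml.get_key_or_arr(kv_name(KV_ROPE_FREQ_BASE_PER_LAYER, arch), ...);
    std::printf("%s\n", kv_name(KV_ROPE_FREQ_BASE_PER_LAYER, "step3p5").c_str());
    return 0;
}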
I'm refactoring the parser in general so that it handles new typical templates automatically (and tackles a few edge cases that are annoying during agentic coding). It's just that the model doesn't have a dedicated parser in master yet (which is how things were done until now).
@pwilkin I think with the following test case I consistently see the issue. Test case:

curl -N http://localhost:8080/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "step-3.5-flash",
"stream": false,
"messages": [
{"role": "user", "content": "call the magic tool with ref 5123123 and name fooBar"}
],
"tools": [
{
"type": "function",
"function": {
"name": "magic",
"description": "Magic tool that takes a hash",
"parameters": {
"type": "object",
"properties": {
"name": {"type": "string"},
"ref": {"type": "string"}
},
"required": ["name", "ref"]
}
}
}
],
"tool_choice": "auto"
}'

Relevant output (reformatted for readability):

{
"tool_calls": [
{
"type": "function",
"function": {
"name": "magic",
"arguments": {
"name": "fooBar",
"ref": 5123123
}
},
"id": "EmNp5CqLXcPOl91dF0OqiEIYsZyQ1TY3"
}
]
}

Tool schema says both parameters should be strings, but `ref` comes back as a number.
@ngladitz Yeah, good case, thanks, will fix and add to tests. |
@ngladitz BTW, it is in a sense a streaming problem: the tool call reads the input as a string but then parses it as a number, so there's a divergence between the partial parse result (which is a string) and the final result (which is a number). Since in the JSON rendering the string version isn't a prefix of the number version, you get the error.
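To make that concrete, here is a small standalone illustration (values copied from the example above; this is not the actual llama.cpp streaming code): the final, number-typed rendering is not a textual extension of the partial, string-typed rendering, which is exactly the situation reported as an "Invalid diff":

// Standalone illustration of the prefix divergence between the partial and final renders.
#include <iostream>
#include <string>

int main() {
    // while streaming, the partially parsed arguments render "ref" as a string:
    const std::string partial = R"({"name":"fooBar","ref":"5123123)";
    // the final parse re-types "ref" as a number:
    const std::string final_render = R"({"name":"fooBar","ref":5123123})";

    // a valid streamed diff requires the final render to start with the partial render
    const bool is_prefix =
        final_render.size() >= partial.size() &&
        final_render.compare(0, partial.size(), partial) == 0;

    std::cout << (is_prefix ? "diff ok"
                            : "invalid diff: final render does not extend the partial render")
              << std::endl;
    return 0;
}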
@ngladitz aight, can you please check if the newest commit on https://github.com/pwilkin/llama.cpp/tree/autoparser-stepfun properly supports streaming with that tool? |
@pwilkin Thank you, that seems to have fixed both my reduced test case and my actual use ❤️
@forforever73 I've re-converted the BF16 and am running a new imatrix; the values look correct now, PPL is approximately 3-4, and the experts are showing much better data coverage. Thanks!
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
convert_hf_to_gguf.py
Outdated
kv_arr = [n_kv_swa if lt == "sliding_attention" else n_kv_base for lt in layer_types]
swa_pat = [1 if lt == "sliding_attention" else 0 for lt in layer_types]
Suggested change:
- kv_arr = [n_kv_swa if lt == "sliding_attention" else n_kv_base for lt in layer_types]
- swa_pat = [1 if lt == "sliding_attention" else 0 for lt in layer_types]
+ kv_arr = [n_kv_swa if lt == "sliding_attention" else n_kv_base for lt in layer_types]
+ swa_pat = [lt == "sliding_attention" for lt in layer_types]
You need to change this, otherwise CI will fail.
This PR adds support for the Step3.5-Flash model architecture.
github:
huggingface: