From e12b8be9fca41c580582c0ebefbcd1f41e36cb6f Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Tue, 21 Oct 2025 14:40:19 +0200 Subject: [PATCH 1/5] [#SAMPLE-9] Stateful DFA processor https://bitcrowd.atlassian.net/browse/SAMPLE-9 From ad19419d6162b107e460a0f441c63095b90c5b66 Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Tue, 21 Oct 2025 15:17:52 +0200 Subject: [PATCH 2/5] Implement dfa_processor --- lib/bumblebee/text/generation.ex | 3 ++ .../text/generation/logits_processing.ex | 43 +++++++++++++++++++ lib/bumblebee/text/generation_config.ex | 4 ++ 3 files changed, 50 insertions(+) diff --git a/lib/bumblebee/text/generation.ex b/lib/bumblebee/text/generation.ex index 669c1b7e..8bcffbe4 100644 --- a/lib/bumblebee/text/generation.ex +++ b/lib/bumblebee/text/generation.ex @@ -371,6 +371,9 @@ defmodule Bumblebee.Text.Generation do if config.forced_token_ids do &forced_tokens_processor(&1, &2, forced_token_ids: config.forced_token_ids) end, + if config.dfa do + &dfa_processor(&1, &2, dfa: config.dfa) + end, if config.temperature && config.temperature != 1.0 do &temperature_processor(&1, &2, temperature: config.temperature) end diff --git a/lib/bumblebee/text/generation/logits_processing.ex b/lib/bumblebee/text/generation/logits_processing.ex index eff38e52..fb50399b 100644 --- a/lib/bumblebee/text/generation/logits_processing.ex +++ b/lib/bumblebee/text/generation/logits_processing.ex @@ -3,6 +3,49 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do import Nx.Defn + deftransform dfa_processor(logits, context, opts \\ []) do + opts = Keyword.validate!(opts, [:dfa]) + dfa = opts[:dfa] + + num_states = + dfa.state_transitions + |> Enum.flat_map(fn {state, _token_id, next_state} -> [state, next_state] end) + |> Enum.uniq() + |> Enum.count() + + empty_state_transitions_tensor = Nx.broadcast(0, {num_states, Nx.size(logits)}) + + state_transitions_tensor = + for transition <- dfa.state_transitions, reduce: empty_state_transitions_tensor do + transitions_tensor -> + {current_state, token_id, next_state} = transition + index = Nx.tensor([current_state, token_id]) + + Nx.indexed_put(transitions_tensor, index, next_state) + end + + initial_state = Nx.tensor([dfa.initial_state]) |> Nx.vectorize(:batch) + + current_state = + if context.length == context.input_length do + initial_state + else + last_state = context.logits_processor_state.dfa + last_token_id = context.sequence[Nx.subtract(context.length, 1)] + + state_transitions_tensor[[last_state, last_token_id]] |> Nx.squeeze() + end + + suppressed_logits = Nx.fill(logits, Nx.Constants.neg_infinity(), type: Nx.type(logits)) + allowed_token_ids = state_transitions_tensor[current_state] + + logits = Nx.select(allowed_token_ids, logits, suppressed_logits) + + context = put_in(context, [:logits_processor_state, :dfa], current_state) + + {logits, context} + end + deftransform suppressed_tokens_processor(logits, _context, opts \\ []) do opts = Keyword.validate!(opts, [:suppressed_token_ids]) diff --git a/lib/bumblebee/text/generation_config.ex b/lib/bumblebee/text/generation_config.ex index d7a6a9a0..78ba9d6e 100644 --- a/lib/bumblebee/text/generation_config.ex +++ b/lib/bumblebee/text/generation_config.ex @@ -93,6 +93,10 @@ defmodule Bumblebee.Text.GenerationConfig do default: [], doc: "a list of token ids to suppress during generation" ], + dfa: [ + default: nil, + doc: "the definition of a deterministic finite automaton (DFA) for constrained generation" + ], no_repeat_ngram_length: [ default: nil, doc: "when set, n-grams of the given length can occur 
only once in the generated sequence" From d0747a78a999059b8d60bfb83a330bf868585def Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Fri, 24 Oct 2025 18:09:31 +0200 Subject: [PATCH 3/5] dfa_processor as module based processor --- lib/bumblebee/text/generation.ex | 2 +- .../text/generation/dfa_processor.ex | 103 ++++++++++++++++++ .../text/generation/logits_processing.ex | 43 -------- test/bumblebee/text/generation_test.exs | 62 +++++++++++ 4 files changed, 166 insertions(+), 44 deletions(-) create mode 100644 lib/bumblebee/text/generation/dfa_processor.ex diff --git a/lib/bumblebee/text/generation.ex b/lib/bumblebee/text/generation.ex index 8bcffbe4..a84a665b 100644 --- a/lib/bumblebee/text/generation.ex +++ b/lib/bumblebee/text/generation.ex @@ -372,7 +372,7 @@ defmodule Bumblebee.Text.Generation do &forced_tokens_processor(&1, &2, forced_token_ids: config.forced_token_ids) end, if config.dfa do - &dfa_processor(&1, &2, dfa: config.dfa) + Bumblebee.configure(Bumblebee.Text.Generation.DFAProcessor, config.dfa) end, if config.temperature && config.temperature != 1.0 do &temperature_processor(&1, &2, temperature: config.temperature) diff --git a/lib/bumblebee/text/generation/dfa_processor.ex b/lib/bumblebee/text/generation/dfa_processor.ex new file mode 100644 index 00000000..2ddf2416 --- /dev/null +++ b/lib/bumblebee/text/generation/dfa_processor.ex @@ -0,0 +1,103 @@ +defmodule Bumblebee.Text.Generation.DFAProcessor do + @moduledoc false + + import Nx.Defn + + @behaviour Bumblebee.Configurable + @behaviour Bumblebee.LogitsProcessor + + options = [ + initial_state: [ + default: nil, + doc: "the initial state" + ], + state_transitions: [ + default: nil, + doc: "the definition of a deterministic finite automaton used for constrained generation" + ], + vocab_size: [ + default: nil, + doc: "the size of the vocabulary" + ] + ] + + defstruct Bumblebee.Shared.option_defaults(options) + + @impl Bumblebee.Configurable + def config(logits_processor, opts) do + Bumblebee.Shared.put_config_attrs(logits_processor, opts) + end + + @impl Bumblebee.LogitsProcessor + def init(logits_processor, _context) do + dfa = logits_processor + + num_states = + dfa.state_transitions + |> Enum.flat_map(fn {state, _token_id, next_state} -> [state, next_state] end) + |> Enum.uniq() + |> Enum.count() + + empty_state_transitions_tensor = Nx.broadcast(0, {num_states, dfa.vocab_size}) + + state_transitions_tensor = + for transition <- dfa.state_transitions, reduce: empty_state_transitions_tensor do + transitions_tensor -> + {current_state, token_id, next_state} = transition + index = Nx.tensor([current_state, token_id]) + + Nx.indexed_put(transitions_tensor, index, next_state) + end + + initial_state = + List.wrap(dfa.initial_state) + |> Enum.map(&List.wrap(&1)) + |> Nx.tensor() + + transition_tensors = state_transitions_tensor + + %{ + dfa_state: %{ + last_state: initial_state, + state_transitions_tensor: transition_tensors + } + } + end + + @impl Bumblebee.LogitsProcessor + def process(_logits_processor, state, logits, context) do + dfa_processing(logits, state, context) + end + + deftransform dfa_processing(logits, state, context) do + transitions_tensor = state.dfa_state.state_transitions_tensor + + last_state = state.dfa_state.last_state |> Nx.vectorize(:batch) + current_state = current_state(context, last_state, transitions_tensor) + logits = logits(logits, transitions_tensor, current_state) + + current_state = Nx.devectorize(current_state, keep_names: false) + + dfa_state = %{state.dfa_state | last_state: current_state} + 
+ state = %{state | dfa_state: dfa_state} + + {logits, state} + end + + defnp current_state(context, last_state, transitions_tensor) do + if context.length == context.input_length do + last_state + else + last_token_id = context.sequence[context.length - 1] + transitions_tensor[[Nx.squeeze(last_state), last_token_id]] + end + end + + defnp logits(logits, transitions_tensor, current_state) do + suppressed_logits = Nx.fill(logits, Nx.Constants.neg_infinity(), type: Nx.type(logits)) + allowed_token_ids = transitions_tensor[Nx.squeeze(current_state)] + + Nx.select(allowed_token_ids, logits, suppressed_logits) + end +end diff --git a/lib/bumblebee/text/generation/logits_processing.ex b/lib/bumblebee/text/generation/logits_processing.ex index fb50399b..eff38e52 100644 --- a/lib/bumblebee/text/generation/logits_processing.ex +++ b/lib/bumblebee/text/generation/logits_processing.ex @@ -3,49 +3,6 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do import Nx.Defn - deftransform dfa_processor(logits, context, opts \\ []) do - opts = Keyword.validate!(opts, [:dfa]) - dfa = opts[:dfa] - - num_states = - dfa.state_transitions - |> Enum.flat_map(fn {state, _token_id, next_state} -> [state, next_state] end) - |> Enum.uniq() - |> Enum.count() - - empty_state_transitions_tensor = Nx.broadcast(0, {num_states, Nx.size(logits)}) - - state_transitions_tensor = - for transition <- dfa.state_transitions, reduce: empty_state_transitions_tensor do - transitions_tensor -> - {current_state, token_id, next_state} = transition - index = Nx.tensor([current_state, token_id]) - - Nx.indexed_put(transitions_tensor, index, next_state) - end - - initial_state = Nx.tensor([dfa.initial_state]) |> Nx.vectorize(:batch) - - current_state = - if context.length == context.input_length do - initial_state - else - last_state = context.logits_processor_state.dfa - last_token_id = context.sequence[Nx.subtract(context.length, 1)] - - state_transitions_tensor[[last_state, last_token_id]] |> Nx.squeeze() - end - - suppressed_logits = Nx.fill(logits, Nx.Constants.neg_infinity(), type: Nx.type(logits)) - allowed_token_ids = state_transitions_tensor[current_state] - - logits = Nx.select(allowed_token_ids, logits, suppressed_logits) - - context = put_in(context, [:logits_processor_state, :dfa], current_state) - - {logits, context} - end - deftransform suppressed_tokens_processor(logits, _context, opts \\ []) do opts = Keyword.validate!(opts, [:suppressed_token_ids]) diff --git a/test/bumblebee/text/generation_test.exs b/test/bumblebee/text/generation_test.exs index 3eff4dba..b0d6d7d6 100644 --- a/test/bumblebee/text/generation_test.exs +++ b/test/bumblebee/text/generation_test.exs @@ -107,6 +107,68 @@ defmodule Bumblebee.Text.GenerationTest do assert_equal(token_ids, Nx.tensor([[80, 1023, 1023]])) end + test "DFA processor" do + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2LMHeadModel"}) + + {:ok, generation_config} = + Bumblebee.load_generation_config({:hf, "hf-internal-testing/tiny-random-GPT2LMHeadModel"}) + + assert %Bumblebee.Text.Gpt2{architecture: :for_causal_language_modeling} = spec + + input_ids = Nx.tensor([[0, 0, 10, 20, 30, 40, 50, 60, 70, 80]]) + attention_mask = Nx.tensor([[0, 0, 1, 1, 1, 1, 1, 1, 1, 1]]) + seed = Nx.tensor([0]) + + inputs = %{ + "input_ids" => Nx.Batch.concatenate([input_ids, input_ids]), + "attention_mask" => Nx.Batch.concatenate([attention_mask, attention_mask]), + "seed" => Nx.Batch.concatenate([seed, seed]) + } + + 
generation_config = Bumblebee.configure(generation_config, max_new_tokens: 4) + + generate = + Bumblebee.Text.Generation.build_generate(model, spec, generation_config, + logits_processors: [ + Bumblebee.configure(Bumblebee.Text.Generation.DFAProcessor, + initial_state: [0, 1], + state_transitions: [ + {0, 1, 1}, + {1, 2, 2}, + {2, 1, 1} + ], + vocab_size: spec.vocab_size + ) + ] + ) + + %{token_ids: token_ids} = generate.(params, inputs) + + # according to DFA definition + # first batch entry starts in state 0 + + # first token_id should be 1 + assert_equal(token_ids[[0, 0]], 1) + + # second token_id should be 2 + assert_equal(token_ids[[0, 1]], 2) + + # third token_id should be 1 + assert_equal(token_ids[[0, 2]], 1) + + # second batch entry starts in state 1 + + # first token_id should be 2 + assert_equal(token_ids[[1, 0]], 2) + + # second token_id should be 1 + assert_equal(token_ids[[1, 1]], 1) + + # third token_id should be 2 + assert_equal(token_ids[[1, 2]], 2) + end + test "with stateful logits processor with different batch sizes" do assert {:ok, %{model: model, params: params, spec: spec}} = Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2LMHeadModel"}) From 2d99571c711f5b56f24e1778ab36c62ae174369b Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Tue, 28 Oct 2025 16:48:14 +0100 Subject: [PATCH 4/5] DFA without state 0 --- .../text/generation/dfa_processor.ex | 4 ++- test/bumblebee/text/generation_test.exs | 34 +++++++++---------- 2 files changed, 20 insertions(+), 18 deletions(-) diff --git a/lib/bumblebee/text/generation/dfa_processor.ex b/lib/bumblebee/text/generation/dfa_processor.ex index 2ddf2416..883f442b 100644 --- a/lib/bumblebee/text/generation/dfa_processor.ex +++ b/lib/bumblebee/text/generation/dfa_processor.ex @@ -38,7 +38,9 @@ defmodule Bumblebee.Text.Generation.DFAProcessor do |> Enum.uniq() |> Enum.count() - empty_state_transitions_tensor = Nx.broadcast(0, {num_states, dfa.vocab_size}) + # we add 1 to num_states as we want to have an empty row for state 0 + # 0 should represent "no transition" as this is the only false value in nx + empty_state_transitions_tensor = Nx.broadcast(0, {num_states + 1, dfa.vocab_size}) state_transitions_tensor = for transition <- dfa.state_transitions, reduce: empty_state_transitions_tensor do diff --git a/test/bumblebee/text/generation_test.exs b/test/bumblebee/text/generation_test.exs index b0d6d7d6..d27b0e39 100644 --- a/test/bumblebee/text/generation_test.exs +++ b/test/bumblebee/text/generation_test.exs @@ -132,11 +132,11 @@ defmodule Bumblebee.Text.GenerationTest do Bumblebee.Text.Generation.build_generate(model, spec, generation_config, logits_processors: [ Bumblebee.configure(Bumblebee.Text.Generation.DFAProcessor, - initial_state: [0, 1], + initial_state: [1, 2], state_transitions: [ - {0, 1, 1}, {1, 2, 2}, - {2, 1, 1} + {2, 3, 3}, + {3, 2, 2} ], vocab_size: spec.vocab_size ) @@ -146,27 +146,27 @@ defmodule Bumblebee.Text.GenerationTest do %{token_ids: token_ids} = generate.(params, inputs) # according to DFA definition - # first batch entry starts in state 0 + # first batch entry starts in state 1 - # first token_id should be 1 - assert_equal(token_ids[[0, 0]], 1) + # first token_id should be 2 + assert_equal(token_ids[[0, 0]], 2) - # second token_id should be 2 - assert_equal(token_ids[[0, 1]], 2) + # second token_id should be 3 + assert_equal(token_ids[[0, 1]], 3) - # third token_id should be 1 - assert_equal(token_ids[[0, 2]], 1) + # third token_id should be 2 + assert_equal(token_ids[[0, 2]], 2) - # second 
batch entry starts in state 1 + # second batch entry starts in state 2 - # first token_id should be 2 - assert_equal(token_ids[[1, 0]], 2) + # first token_id should be 3 + assert_equal(token_ids[[1, 0]], 3) - # second token_id should be 1 - assert_equal(token_ids[[1, 1]], 1) + # second token_id should be 2 + assert_equal(token_ids[[1, 1]], 2) - # third token_id should be 2 - assert_equal(token_ids[[1, 2]], 2) + # third token_id should be 3 + assert_equal(token_ids[[1, 2]], 3) end test "with stateful logits processor with different batch sizes" do From 7e80d4662b9a5ecbd934b88eedd3d883ee28200e Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Fri, 7 Nov 2025 15:10:01 +0100 Subject: [PATCH 5/5] vectorize state in init/2 --- .../text/generation/dfa_processor.ex | 32 +++++++------------ 1 file changed, 11 insertions(+), 21 deletions(-) diff --git a/lib/bumblebee/text/generation/dfa_processor.ex b/lib/bumblebee/text/generation/dfa_processor.ex index 883f442b..5d80cbfd 100644 --- a/lib/bumblebee/text/generation/dfa_processor.ex +++ b/lib/bumblebee/text/generation/dfa_processor.ex @@ -29,7 +29,7 @@ defmodule Bumblebee.Text.Generation.DFAProcessor do end @impl Bumblebee.LogitsProcessor - def init(logits_processor, _context) do + def init(logits_processor, context) do dfa = logits_processor num_states = @@ -51,18 +51,12 @@ defmodule Bumblebee.Text.Generation.DFAProcessor do Nx.indexed_put(transitions_tensor, index, next_state) end - initial_state = - List.wrap(dfa.initial_state) - |> Enum.map(&List.wrap(&1)) - |> Nx.tensor() - - transition_tensors = state_transitions_tensor + initial_state = Nx.tensor(dfa.initial_state) + [initial_state, _sequence] = Nx.broadcast_vectors([initial_state, context.sequence]) %{ - dfa_state: %{ - last_state: initial_state, - state_transitions_tensor: transition_tensors - } + last_state: initial_state, + state_transitions_tensor: state_transitions_tensor } end @@ -72,19 +66,15 @@ defmodule Bumblebee.Text.Generation.DFAProcessor do end deftransform dfa_processing(logits, state, context) do - transitions_tensor = state.dfa_state.state_transitions_tensor + transitions_tensor = state.state_transitions_tensor + last_state = state.last_state - last_state = state.dfa_state.last_state |> Nx.vectorize(:batch) current_state = current_state(context, last_state, transitions_tensor) logits = logits(logits, transitions_tensor, current_state) - current_state = Nx.devectorize(current_state, keep_names: false) - - dfa_state = %{state.dfa_state | last_state: current_state} - - state = %{state | dfa_state: dfa_state} + state = %{state | last_state: current_state} - {logits, state} + {state, logits} end defnp current_state(context, last_state, transitions_tensor) do @@ -92,13 +82,13 @@ defmodule Bumblebee.Text.Generation.DFAProcessor do last_state else last_token_id = context.sequence[context.length - 1] - transitions_tensor[[Nx.squeeze(last_state), last_token_id]] + transitions_tensor[[last_state, last_token_id]] end end defnp logits(logits, transitions_tensor, current_state) do suppressed_logits = Nx.fill(logits, Nx.Constants.neg_infinity(), type: Nx.type(logits)) - allowed_token_ids = transitions_tensor[Nx.squeeze(current_state)] + allowed_token_ids = transitions_tensor[current_state] Nx.select(allowed_token_ids, logits, suppressed_logits) end
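
For reference, a small illustration of the transition-table encoding that `init/2` builds in the final patch above. Rows are states, columns are token ids, and a cell holds the next state, with 0 reserved for "no transition" (the only falsy value for the `Nx.select/3` mask, per the comment in PATCH 4/5). The concrete numbers here are hypothetical:

    transitions = [{1, 2, 2}, {2, 3, 3}, {3, 2, 2}]
    vocab_size = 5
    num_states = 3

    # Row 0 stays empty so that state 0 can encode "no transition".
    table =
      for {state, token_id, next_state} <- transitions,
          reduce: Nx.broadcast(0, {num_states + 1, vocab_size}) do
        acc -> Nx.indexed_put(acc, Nx.tensor([state, token_id]), next_state)
      end

    # In state 1 only token id 2 is allowed (and it moves the DFA to state 2):
    Nx.to_flat_list(table[1])
    #=> [0, 0, 2, 0, 0]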
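
And a minimal end-to-end sketch assembled from the test in this series (the model, token ids, and DFA are illustrative; `DFAProcessor` and the `logits_processors:` option are the ones introduced by these patches):

    {:ok, %{model: model, params: params, spec: spec}} =
      Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2LMHeadModel"})

    {:ok, generation_config} =
      Bumblebee.load_generation_config({:hf, "hf-internal-testing/tiny-random-GPT2LMHeadModel"})

    generation_config = Bumblebee.configure(generation_config, max_new_tokens: 4)

    # One initial state per batch entry; {1, 2, 2} reads: in state 1,
    # token id 2 is allowed and moves the automaton to state 2.
    dfa =
      Bumblebee.configure(Bumblebee.Text.Generation.DFAProcessor,
        initial_state: [1, 2],
        state_transitions: [
          {1, 2, 2},
          {2, 3, 3},
          {3, 2, 2}
        ],
        vocab_size: spec.vocab_size
      )

    generate =
      Bumblebee.Text.Generation.build_generate(model, spec, generation_config,
        logits_processors: [dfa]
      )

    input_ids = Nx.tensor([[0, 0, 10, 20, 30, 40, 50, 60, 70, 80]])
    attention_mask = Nx.tensor([[0, 0, 1, 1, 1, 1, 1, 1, 1, 1]])
    seed = Nx.tensor([0])

    inputs = %{
      "input_ids" => Nx.Batch.concatenate([input_ids, input_ids]),
      "attention_mask" => Nx.Batch.concatenate([attention_mask, attention_mask]),
      "seed" => Nx.Batch.concatenate([seed, seed])
    }

    # The first batch entry starts in state 1 and can only emit 2, 3, 2, ...;
    # the second starts in state 2 and emits 3, 2, 3, ...
    %{token_ids: token_ids} = generate.(params, inputs)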