diff --git a/lib/bumblebee/text/generation.ex b/lib/bumblebee/text/generation.ex
index 669c1b7e..a84a665b 100644
--- a/lib/bumblebee/text/generation.ex
+++ b/lib/bumblebee/text/generation.ex
@@ -371,6 +371,9 @@ defmodule Bumblebee.Text.Generation do
       if config.forced_token_ids do
         &forced_tokens_processor(&1, &2, forced_token_ids: config.forced_token_ids)
       end,
+      if config.dfa do
+        Bumblebee.configure(Bumblebee.Text.Generation.DFAProcessor, config.dfa)
+      end,
       if config.temperature && config.temperature != 1.0 do
         &temperature_processor(&1, &2, temperature: config.temperature)
       end
diff --git a/lib/bumblebee/text/generation/dfa_processor.ex b/lib/bumblebee/text/generation/dfa_processor.ex
new file mode 100644
index 00000000..5d80cbfd
--- /dev/null
+++ b/lib/bumblebee/text/generation/dfa_processor.ex
@@ -0,0 +1,95 @@
+defmodule Bumblebee.Text.Generation.DFAProcessor do
+  @moduledoc false
+
+  import Nx.Defn
+
+  @behaviour Bumblebee.Configurable
+  @behaviour Bumblebee.LogitsProcessor
+
+  options = [
+    initial_state: [
+      default: nil,
+      doc: "the initial DFA state, or a list of initial states (one per batch entry)"
+    ],
+    state_transitions: [
+      default: nil,
+      doc: "the DFA transitions used for constrained generation, given as a list of {current_state, token_id, next_state} tuples"
+    ],
+    vocab_size: [
+      default: nil,
+      doc: "the size of the vocabulary"
+    ]
+  ]
+
+  defstruct Bumblebee.Shared.option_defaults(options)
+
+  @impl Bumblebee.Configurable
+  def config(logits_processor, opts) do
+    Bumblebee.Shared.put_config_attrs(logits_processor, opts)
+  end
+
+  @impl Bumblebee.LogitsProcessor
+  def init(logits_processor, context) do
+    dfa = logits_processor
+
+    num_states =
+      dfa.state_transitions
+      |> Enum.flat_map(fn {state, _token_id, next_state} -> [state, next_state] end)
+      |> Enum.uniq()
+      |> Enum.count()
+
+    # We add 1 to num_states to keep an empty row for state 0. State 0
+    # represents "no transition", since 0 is the only falsy value in Nx.
+    empty_state_transitions_tensor = Nx.broadcast(0, {num_states + 1, dfa.vocab_size})
+
+    state_transitions_tensor =
+      for transition <- dfa.state_transitions, reduce: empty_state_transitions_tensor do
+        transitions_tensor ->
+          {current_state, token_id, next_state} = transition
+          index = Nx.tensor([current_state, token_id])
+
+          Nx.indexed_put(transitions_tensor, index, next_state)
+      end
+
+    initial_state = Nx.tensor(dfa.initial_state)
+    [initial_state, _sequence] = Nx.broadcast_vectors([initial_state, context.sequence])
+
+    %{
+      last_state: initial_state,
+      state_transitions_tensor: state_transitions_tensor
+    }
+  end
+
+  @impl Bumblebee.LogitsProcessor
+  def process(_logits_processor, state, logits, context) do
+    dfa_processing(logits, state, context)
+  end
+
+  deftransform dfa_processing(logits, state, context) do
+    transitions_tensor = state.state_transitions_tensor
+    last_state = state.last_state
+
+    current_state = current_state(context, last_state, transitions_tensor)
+    logits = logits(logits, transitions_tensor, current_state)
+
+    state = %{state | last_state: current_state}
+
+    {state, logits}
+  end
+
+  defnp current_state(context, last_state, transitions_tensor) do
+    if context.length == context.input_length do
+      last_state
+    else
+      last_token_id = context.sequence[context.length - 1]
+      transitions_tensor[[last_state, last_token_id]]
+    end
+  end
+
+  defnp logits(logits, transitions_tensor, current_state) do
+    suppressed_logits = Nx.fill(logits, Nx.Constants.neg_infinity(), type: Nx.type(logits))
+    allowed_token_ids = transitions_tensor[current_state]
+
+    Nx.select(allowed_token_ids, logits, suppressed_logits)
+  end
+end
diff --git a/lib/bumblebee/text/generation_config.ex b/lib/bumblebee/text/generation_config.ex
index d7a6a9a0..78ba9d6e 100644
--- a/lib/bumblebee/text/generation_config.ex
+++ b/lib/bumblebee/text/generation_config.ex
@@ -93,6 +93,10 @@ defmodule Bumblebee.Text.GenerationConfig do
       default: [],
       doc: "a list of token ids to suppress during generation"
     ],
+    dfa: [
+      default: nil,
+      doc: "the definition of a deterministic finite automaton (DFA) for constrained generation, see Bumblebee.Text.Generation.DFAProcessor for the expected options"
+    ],
     no_repeat_ngram_length: [
       default: nil,
       doc: "when set, n-grams of the given length can occur only once in the generated sequence"
diff --git a/test/bumblebee/text/generation_test.exs b/test/bumblebee/text/generation_test.exs
index 3eff4dba..d27b0e39 100644
--- a/test/bumblebee/text/generation_test.exs
+++ b/test/bumblebee/text/generation_test.exs
@@ -107,6 +107,68 @@ defmodule Bumblebee.Text.GenerationTest do
     assert_equal(token_ids, Nx.tensor([[80, 1023, 1023]]))
   end
 
+  test "DFA processor" do
+    assert {:ok, %{model: model, params: params, spec: spec}} =
+             Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2LMHeadModel"})
+
+    {:ok, generation_config} =
+      Bumblebee.load_generation_config({:hf, "hf-internal-testing/tiny-random-GPT2LMHeadModel"})
+
+    assert %Bumblebee.Text.Gpt2{architecture: :for_causal_language_modeling} = spec
+
+    input_ids = Nx.tensor([[0, 0, 10, 20, 30, 40, 50, 60, 70, 80]])
+    attention_mask = Nx.tensor([[0, 0, 1, 1, 1, 1, 1, 1, 1, 1]])
+    seed = Nx.tensor([0])
+
+    inputs = %{
+      "input_ids" => Nx.Batch.concatenate([input_ids, input_ids]),
+      "attention_mask" => Nx.Batch.concatenate([attention_mask, attention_mask]),
+      "seed" => Nx.Batch.concatenate([seed, seed])
+    }
+
+    generation_config = Bumblebee.configure(generation_config, max_new_tokens: 4)
+
+    generate =
+      Bumblebee.Text.Generation.build_generate(model, spec, generation_config,
+        logits_processors: [
+          Bumblebee.configure(Bumblebee.Text.Generation.DFAProcessor,
+            initial_state: [1, 2],
+            state_transitions: [
+              {1, 2, 2},
+              {2, 3, 3},
+              {3, 2, 2}
+            ],
+            vocab_size: spec.vocab_size
+          )
+        ]
+      )
+
+    %{token_ids: token_ids} = generate.(params, inputs)
+
+    # According to the DFA definition,
+    # the first batch entry starts in state 1
+
+    # first token_id should be 2
+    assert_equal(token_ids[[0, 0]], 2)
+
+    # second token_id should be 3
+    assert_equal(token_ids[[0, 1]], 3)
+
+    # third token_id should be 2
+    assert_equal(token_ids[[0, 2]], 2)
+
+    # the second batch entry starts in state 2
+
+    # first token_id should be 3
+    assert_equal(token_ids[[1, 0]], 3)
+
+    # second token_id should be 2
+    assert_equal(token_ids[[1, 1]], 2)
+
+    # third token_id should be 3
+    assert_equal(token_ids[[1, 2]], 3)
+  end
+
   test "with stateful logits processor with different batch sizes" do
     assert {:ok, %{model: model, params: params, spec: spec}} =
              Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2LMHeadModel"})
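
For reviewers, a minimal end-to-end usage sketch of the new `dfa` option as wired into `Bumblebee.Text.GenerationConfig` above (not part of the diff). The repo name, prompt, and token ids are illustrative placeholders; with the two-state DFA below, generation alternates between the two listed token ids for the whole `max_new_tokens` budget:

```elixir
# Hypothetical usage of the dfa generation option. The token ids
# (5_297 and 11_265) are arbitrary placeholders, not specific words.
{:ok, model_info} = Bumblebee.load_model({:hf, "openai-community/gpt2"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai-community/gpt2"})
{:ok, generation_config} = Bumblebee.load_generation_config({:hf, "openai-community/gpt2"})

generation_config =
  Bumblebee.configure(generation_config,
    max_new_tokens: 8,
    # Each transition is {current_state, token_id, next_state}. State 0 is
    # reserved for "no transition", so states are numbered from 1.
    dfa: [
      initial_state: 1,
      state_transitions: [
        {1, 5_297, 2},
        {2, 11_265, 1}
      ],
      vocab_size: model_info.spec.vocab_size
    ]
  )

serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config)
Nx.Serving.run(serving, "An illustrative prompt")
```

Note that the processor masks logits at every step, so any state without an outgoing transition for the sampled vocabulary suppresses all tokens, including EOS; a practical DFA should either loop or reserve a transition for the EOS token id.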