diff --git a/.gitignore b/.gitignore index 32da76df8..be879b728 100644 --- a/.gitignore +++ b/.gitignore @@ -22,4 +22,4 @@ docs/source/generated # docs/source/_static/model_table **.orig .venv - +.env diff --git a/.vscode/settings.json b/.vscode/settings.json index 63e6e310a..86d448657 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -33,7 +33,7 @@ "notebook.formatOnSave.enabled": true, "pylint.importStrategy": "fromEnvironment", "python.testing.pytestArgs": [ - "transformer_lens", + "tests" ], "python.testing.pytestEnabled": true, "rewrap.autoWrap.enabled": true, diff --git a/tests/acceptance/test_hooked_encoder.py b/tests/acceptance/test_hooked_encoder.py index d0f746d60..797ecbbf9 100644 --- a/tests/acceptance/test_hooked_encoder.py +++ b/tests/acceptance/test_hooked_encoder.py @@ -225,6 +225,6 @@ def test_input_list_of_strings_mlm(our_bert, huggingface_bert, tokenizer): @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires a CUDA device") -def test_cuda(mlm_tokens): +def test_cuda(tokens): model = HookedEncoder.from_pretrained(MODEL_NAME) - model(mlm_tokens) + model(tokens) diff --git a/tests/acceptance/test_multi_gpu.py b/tests/acceptance/test_multi_gpu.py index 3af5eeeb2..ad407eb6e 100644 --- a/tests/acceptance/test_multi_gpu.py +++ b/tests/acceptance/test_multi_gpu.py @@ -111,7 +111,7 @@ def test_cache_device(): torch.device("cuda:1") ) - logits, cache = model.run_with_cache("Hello there", device="cpu") + logits, cache = model.run_with_cache("Hello there", device=torch.device("cpu")) assert norm_device(cache["blocks.0.mlp.hook_post"].device) == norm_device(torch.device("cpu")) model.to("cuda") diff --git a/tests/integration/test_generation_compatibility.py b/tests/integration/test_generation_compatibility.py new file mode 100644 index 000000000..5af4af9a7 --- /dev/null +++ b/tests/integration/test_generation_compatibility.py @@ -0,0 +1,509 @@ +"""Integration tests for generation API compatibility. + +This module tests generation API features including HuggingFace-style ModelOutput +support and TransformerBridge batch dimension compatibility. 
+""" + +import warnings + +import pytest +import torch + +from transformer_lens import HookedTransformer +from transformer_lens.model_bridge import TransformerBridge + + +@pytest.fixture(scope="module") +def gpt2_ht(): + """Load GPT-2 HookedTransformer once per module.""" + return HookedTransformer.from_pretrained("gpt2", device="cpu") + + +@pytest.fixture(scope="module") +def gpt2_bridge(): + """Load GPT-2 TransformerBridge once per module.""" + bridge = TransformerBridge.boot_transformers("gpt2", device="cpu") + if bridge.tokenizer.pad_token is None: + bridge.tokenizer.pad_token = bridge.tokenizer.eos_token + return bridge + + +class TestHookedTransformerGenerationModelOutput: + """Tests for HookedTransformer generation with ModelOutput returns.""" + + def test_generate_with_output_logits_returns_modeloutput(self, gpt2_ht): + """Test that output_logits=True returns a ModelOutput with sequences and logits.""" + prompt = "The quick brown" + max_new_tokens = 5 + + result = gpt2_ht.generate( + prompt, + max_new_tokens=max_new_tokens, + do_sample=False, + verbose=False, + output_logits=True, + ) + + # Check that we got a ModelOutput-like object + assert hasattr(result, "sequences"), "Result should have sequences attribute" + assert hasattr(result, "logits"), "Result should have logits attribute" + + # Check sequences shape and type + assert isinstance(result.sequences, torch.Tensor), "sequences should be a tensor" + assert result.sequences.ndim == 2, "sequences should be 2D [batch, pos]" + + # Check logits structure and shape + assert isinstance(result.logits, tuple), "logits should be a tuple" + assert ( + len(result.logits) == max_new_tokens + ), f"logits tuple should have {max_new_tokens} elements" + + # Each logit tensor should be [batch, vocab] + for i, logit in enumerate(result.logits): + assert isinstance(logit, torch.Tensor), f"logits[{i}] should be a tensor" + assert logit.ndim == 2, f"logits[{i}] should be 2D [batch, vocab]" + assert ( + logit.shape[0] == result.sequences.shape[0] + ), f"logits[{i}] batch size should match sequences" + assert ( + logit.shape[1] == gpt2_ht.cfg.d_vocab + ), f"logits[{i}] vocab size should match model config" + + def test_generate_without_output_logits_returns_normal(self, gpt2_ht): + """Test that without output_logits flag, generation returns normal format.""" + prompt = "The quick brown" + + result = gpt2_ht.generate( + prompt, + max_new_tokens=5, + do_sample=False, + verbose=False, + ) + + # Should return a string (default return_type="input" with string input) + assert isinstance(result, str), "Result should be a string" + assert len(result) > len(prompt), "Generated text should be longer than prompt" + + def test_generate_output_logits_with_return_type_tokens(self, gpt2_ht): + """Test output_logits with return_type='tokens' returns ModelOutput with token sequences.""" + prompt = "Hello world" + max_new_tokens = 3 + + result = gpt2_ht.generate( + prompt, + max_new_tokens=max_new_tokens, + return_type="tokens", + do_sample=False, + verbose=False, + output_logits=True, + ) + + # Check ModelOutput structure + assert hasattr(result, "sequences"), "Result should have sequences" + assert hasattr(result, "logits"), "Result should have logits" + + # Sequences should be tokens + assert isinstance(result.sequences, torch.Tensor), "sequences should be a tensor" + assert result.sequences.dtype in [ + torch.long, + torch.int, + torch.int64, + ], "sequences should be integer tokens" + + # Check logits + assert len(result.logits) == max_new_tokens, "logits should 
match max_new_tokens" + + def test_return_dict_in_generate_silently_ignored(self, gpt2_ht): + """Test that return_dict_in_generate is silently ignored without warnings.""" + prompt = "Test" + + # Should not raise any warnings + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + result = gpt2_ht.generate( + prompt, + max_new_tokens=2, + verbose=False, + return_dict_in_generate=True, # Should be silently ignored + ) + + # Check no warnings were raised + assert len(w) == 0, "return_dict_in_generate should be silently ignored" + + # Result should still be normal (string) + assert isinstance(result, str), "Result should be a string" + + def test_unsupported_hf_flags_trigger_warning(self, gpt2_ht): + """Test that unsupported HF generation kwargs trigger UserWarning.""" + prompt = "Test" + + with pytest.warns(UserWarning, match="unsupported generation kwargs"): + result = gpt2_ht.generate( + prompt, + max_new_tokens=2, + verbose=False, + output_scores=True, # Unsupported flag + output_attentions=True, # Unsupported flag + ) + + # Result should still work (string) + assert isinstance(result, str), "Result should be a string despite warnings" + + def test_logits_consistency_with_forward_pass(self, gpt2_ht): + """Test that logits from generate match those from forward pass.""" + prompt = "Hello" + + # Generate with output_logits + result = gpt2_ht.generate( + prompt, + max_new_tokens=1, + do_sample=False, + verbose=False, + output_logits=True, + ) + + # Get first generated token from sequences + first_new_token = result.sequences[0, -1] + + # Get logits for that token + first_logits = result.logits[0][0] + + # The argmax of logits should match the generated token (since do_sample=False) + assert first_logits.argmax() == first_new_token, "Greedy token should match logits argmax" + + def test_output_logits_batch_generation(self, gpt2_ht): + """Test output_logits works with batch inputs.""" + prompts = ["Hello", "World"] + max_new_tokens = 3 + + result = gpt2_ht.generate( + prompts, + max_new_tokens=max_new_tokens, + do_sample=False, + verbose=False, + output_logits=True, + ) + + # Check batch dimension + assert result.sequences.shape[0] == len( + prompts + ), "Batch dimension should match number of prompts" + + # Check logits batch dimension + for logit in result.logits: + assert logit.shape[0] == len(prompts), "Logits batch dimension should match prompts" + + +class TestTransformerBridgeGenerationModelOutput: + """Tests for TransformerBridge generation with ModelOutput returns.""" + + def test_generate_with_output_logits_returns_modeloutput(self, gpt2_bridge): + """Test that output_logits=True returns a ModelOutput with sequences and logits.""" + prompt = "The quick brown" + max_new_tokens = 5 + + result = gpt2_bridge.generate( + prompt, + max_new_tokens=max_new_tokens, + do_sample=False, + verbose=False, + output_logits=True, + ) + + # Check that we got a ModelOutput-like object + assert hasattr(result, "sequences"), "Result should have sequences attribute" + assert hasattr(result, "logits"), "Result should have logits attribute" + + # Check sequences shape and type + assert isinstance(result.sequences, torch.Tensor), "sequences should be a tensor" + assert result.sequences.ndim == 2, "sequences should be 2D [batch, pos]" + + # Check logits structure and shape + assert isinstance(result.logits, tuple), "logits should be a tuple" + assert ( + len(result.logits) == max_new_tokens + ), f"logits tuple should have {max_new_tokens} elements" + + # Each logit tensor should 
be [batch, vocab] + for i, logit in enumerate(result.logits): + assert isinstance(logit, torch.Tensor), f"logits[{i}] should be a tensor" + assert logit.ndim == 2, f"logits[{i}] should be 2D [batch, vocab]" + assert ( + logit.shape[0] == result.sequences.shape[0] + ), f"logits[{i}] batch size should match sequences" + assert ( + logit.shape[1] == gpt2_bridge.cfg.d_vocab + ), f"logits[{i}] vocab size should match model config" + + def test_generate_without_output_logits_returns_normal(self, gpt2_bridge): + """Test that without output_logits flag, generation returns normal format.""" + prompt = "The quick brown" + + result = gpt2_bridge.generate( + prompt, + max_new_tokens=5, + do_sample=False, + verbose=False, + ) + + # Should return a string (default return_type="input" with string input) + assert isinstance(result, str), "Result should be a string" + assert len(result) > len(prompt), "Generated text should be longer than prompt" + + def test_generate_output_logits_batch(self, gpt2_bridge): + """Test output_logits works with batch inputs.""" + prompts = ["Hello", "World"] + max_new_tokens = 3 + + result = gpt2_bridge.generate( + prompts, + max_new_tokens=max_new_tokens, + do_sample=False, + verbose=False, + output_logits=True, + ) + + # Check ModelOutput structure + assert hasattr(result, "sequences"), "Result should have sequences" + assert hasattr(result, "logits"), "Result should have logits" + + # Check batch dimension + assert result.sequences.shape[0] == len( + prompts + ), "Batch dimension should match number of prompts" + + # Check logits batch dimension + for logit in result.logits: + assert logit.shape[0] == len(prompts), "Logits batch dimension should match prompts" + + +class TestTransformerBridgeHFGenerate: + """Tests for TransformerBridge.hf_generate() with full HF API support.""" + + def test_hf_generate_with_output_scores(self, gpt2_bridge): + """Test that output_scores is forwarded to HF model.""" + prompt = "Test" + + # output_scores should be forwarded without error + result = gpt2_bridge.hf_generate( + prompt, + max_new_tokens=3, + do_sample=False, + output_scores=True, + ) + + # Should return a string (default behavior with string input) + assert isinstance(result, str), "Result should be a string" + + def test_hf_generate_sets_return_dict_in_generate(self, gpt2_bridge): + """Test that hf_dict_flags automatically set return_dict_in_generate=True.""" + prompt = "Hello" + + # When we pass output_logits, return_dict_in_generate should be auto-set + # We can't directly inspect the HF call, but we can verify it doesn't error + result = gpt2_bridge.hf_generate( + prompt, + max_new_tokens=2, + do_sample=False, + output_logits=True, + ) + + # Should work without error + assert isinstance(result, str), "Result should be generated successfully" + + def test_hf_generate_multiple_flags_simultaneously(self, gpt2_bridge): + """Test that multiple HF-style flags can be passed simultaneously.""" + prompt = "Test" + + result = gpt2_bridge.hf_generate( + prompt, + max_new_tokens=2, + do_sample=False, + output_logits=True, + output_attentions=True, + output_hidden_states=True, + ) + + # Should work and return a result + assert isinstance(result, str), "Result should be generated with multiple flags" + + def test_hf_generate_return_type_tokens(self, gpt2_bridge): + """Test return_type='tokens' works with HF flags.""" + prompt = "Hello" + + result = gpt2_bridge.hf_generate( + prompt, + max_new_tokens=2, + return_type="tokens", + do_sample=False, + output_logits=True, + ) + + # With 
return_type='tokens', we should get either tokens tensor or ModelOutput + # The implementation returns the raw HF output for tokens + assert result is not None, "Result should not be None" + + def test_hf_generate_flags_coerced_to_bool(self, gpt2_bridge): + """Test that HF flags are properly coerced to boolean values.""" + prompt = "Test" + + # Pass non-boolean values that should be coerced to bool + result = gpt2_bridge.hf_generate( + prompt, + max_new_tokens=2, + do_sample=False, + output_logits=1, # Should be coerced to True + output_scores=0, # 0 is not None, so flag is set (coerces to False) + ) + + # Should work without error + assert isinstance(result, str) or result is not None, "Result should be generated" + + def test_hf_generate_batch_generation(self, gpt2_bridge): + """Test batch generation works with HF-style flags.""" + prompts = ["Hello", "World"] + + result = gpt2_bridge.hf_generate( + prompts, + max_new_tokens=2, + do_sample=False, + output_logits=True, + ) + + # Should return list of strings for batch input + assert isinstance(result, list), "Batch input should return list" + assert len(result) == len(prompts), "Output list should match input length" + + +class TestGenerationBackwardCompatibility: + """Tests to ensure backward compatibility with existing generation usage.""" + + def test_hooked_transformer_basic_generation_unchanged(self, gpt2_ht): + """Test that basic generation without new flags works as before.""" + prompt = "Hello world" + + result = gpt2_ht.generate( + prompt, + max_new_tokens=5, + do_sample=False, + verbose=False, + ) + + assert isinstance(result, str), "Basic generation should return string" + assert len(result) > len(prompt), "Generated text should be longer" + + def test_bridge_basic_generation_unchanged(self, gpt2_bridge): + """Test that basic bridge generation without new flags works as before.""" + prompt = "Hello world" + + result = gpt2_bridge.generate( + prompt, + max_new_tokens=5, + do_sample=False, + verbose=False, + ) + + assert isinstance(result, str), "Basic generation should return string" + assert len(result) > len(prompt), "Generated text should be longer" + + def test_hooked_transformer_return_types_unchanged(self, gpt2_ht): + """Test that all return_type options still work.""" + prompt = "Test" + + # Test return_type='str' + result_str = gpt2_ht.generate( + prompt, max_new_tokens=2, return_type="str", verbose=False, do_sample=False + ) + assert isinstance(result_str, str), "return_type='str' should return string" + + # Test return_type='tokens' + result_tokens = gpt2_ht.generate( + prompt, max_new_tokens=2, return_type="tokens", verbose=False, do_sample=False + ) + assert isinstance(result_tokens, torch.Tensor), "return_type='tokens' should return tensor" + + # Test return_type='embeds' + result_embeds = gpt2_ht.generate( + prompt, max_new_tokens=2, return_type="embeds", verbose=False, do_sample=False + ) + assert isinstance(result_embeds, torch.Tensor), "return_type='embeds' should return tensor" + assert result_embeds.ndim == 3, "Embeddings should be 3D" + + +class TestBlockBridgeBatchCompatibility: + """Tests for BlockBridge tuple return format and batch dimension preservation.""" + + def test_block_bridge_batched_generation_compatibility(self, gpt2_bridge): + """Test BlockBridge maintains tuple format and batch dimensions during generation. + + This test exercises two critical aspects of improved HF compatibility: + 1. BlockBridge.forward() always returns tuples (not bare tensors) + 2. 
Batch dimensions are preserved through multi-block generation pipeline + """ + # Test 1: Direct block forward returns tuple with preserved batch dimension + batch_size = 2 + seq_len = 8 + hidden_dim = gpt2_bridge.cfg.d_model + hidden_states = torch.randn(batch_size, seq_len, hidden_dim) + + # Get first transformer block (BlockBridge component) + first_block = gpt2_bridge.original_model.transformer.h[0] + + # Call forward - this is what HF's GPT2Model does in its loop + block_output = first_block(hidden_states) + + # BlockBridge must return tuple + assert isinstance( + block_output, tuple + ), f"BlockBridge must return tuple for HF compatibility, got {type(block_output)}" + + # Verify first element is a tensor + assert isinstance( + block_output[0], torch.Tensor + ), "First element of BlockBridge output must be a tensor" + + # Batch dimension must be preserved + # Without tuple wrapping, outputs[0] idx op would turn [batch, seq, dim] -> [seq, dim] + assert block_output[0].shape == ( + batch_size, + seq_len, + hidden_dim, + ), f"Expected shape [{batch_size}, {seq_len}, {hidden_dim}], got {block_output[0].shape}" + + assert ( + block_output[0].shape[0] == batch_size + ), f"Batch dimension lost! Expected {batch_size}, got {block_output[0].shape[0]}" + + # Test 2: Batched generation works end-to-end through multiple blocks + prompts = ["Hello world", "Goodbye world"] + + # Tokenize with left padding + tokens = gpt2_bridge.to_tokens(prompts, prepend_bos=False, padding_side="left") + + # Generate tokens - this exercises the full HF generation loop with multiple blocks + output = gpt2_bridge.generate( + tokens, + max_new_tokens=4, + do_sample=False, # Deterministic for testing + use_past_kv_cache=True, + verbose=False, + ) + + # Verify output preserves batch dimension + assert output.shape[0] == len( + prompts + ), f"Batch size must be preserved through generation. 
Expected {len(prompts)}, got {output.shape[0]}" + + # Verify we actually generated new tokens + assert ( + output.shape[1] > tokens.shape[1] + ), "Generation should produce longer sequences than input" + + # Verify batch items remain independent (not collapsed into single item) + assert not torch.equal( + output[0], output[1] + ), "Batch items should be independent - different prompts should produce different outputs" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/unit/components/test_attention.py b/tests/unit/components/test_attention.py index b386660c6..c473cc491 100644 --- a/tests/unit/components/test_attention.py +++ b/tests/unit/components/test_attention.py @@ -80,6 +80,25 @@ def test_attention_load_in_4bit(): assert torch.all(attn.b_V == 0) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for half/bfloat16 tests") +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) +def test_attention_forward_half_precisions(dtype): + # Construct a small attention block + cfg = HookedTransformerConfig( + d_model=64, d_head=16, n_heads=4, n_layers=1, n_ctx=8, dtype=dtype + ) + attn = Attention(cfg) + # Random inputs in the matching dtype + batch = 1 + seq = 4 + x = torch.rand((batch, seq, cfg.d_model), dtype=dtype).to("cuda") + # Run forward through attention (q,k,v = x) + out = attn(x, x, x) + # Should not raise and return a tensor on cuda with same dtype as cfg or compatible + assert isinstance(out, torch.Tensor) + assert out.device.type == "cuda" + + def test_attention_config_dict(): cfg = { "n_layers": 12, diff --git a/tests/unit/factored_matrix/test_multiply_by_scalar.py b/tests/unit/factored_matrix/test_multiply_by_scalar.py index 85d0bfbe7..d5fbf29ba 100644 --- a/tests/unit/factored_matrix/test_multiply_by_scalar.py +++ b/tests/unit/factored_matrix/test_multiply_by_scalar.py @@ -23,6 +23,7 @@ ), # Non-scalar Tensor. AssertionError expected. (torch.rand(2), AssertionError), # Non-scalar Tensor. AssertionError expected. ], + ids=["tensor", "float", "int", "tensor_2d", "tensor_1d"], ) @pytest.mark.parametrize("leading_dim", [False, True]) @pytest.mark.parametrize("multiply_from_left", [False, True]) diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py index 4bcece7b1..0b8181a91 100644 --- a/transformer_lens/HookedTransformer.py +++ b/transformer_lens/HookedTransformer.py @@ -14,6 +14,7 @@ import logging import os from typing import ( + Any, Dict, List, NamedTuple, @@ -1840,11 +1841,15 @@ def generate( padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE, return_type: Optional[str] = "input", verbose: bool = True, + **generation_kwargs, ) -> Union[ str, List[str], Int[torch.Tensor, "batch pos_plus_new_tokens"], Float[torch.Tensor, "batch pos_plus_new_tokens hidden_size"], + Any, # transformers.utils.ModelOutput to accommodate output_logits=True. + # Using Any due to beartype's forward reference resolution limitations. + # See: https://github.com/beartype/beartype/issues/546 ]: """Sample Tokens from the Model. @@ -1943,6 +1948,34 @@ def generate( else: past_kv_cache = None + # We only support a single HF style generation kwarg: `output_logits` which will cause + # the function to return a ModelOutput-like object containing `sequences` and `logits`. + # Any other HF-style generation kwargs are rejected to avoid supporting the full HF API here. 
+ output_logits_flag = False + if generation_kwargs: + if "output_logits" in generation_kwargs: + output_logits_flag = bool(generation_kwargs.pop("output_logits")) + # Identify keys to warn about: anything other than allowed/silently ignored + accepted_keys = {"output_logits", "return_dict_in_generate"} + unsupported_keys = [k for k in generation_kwargs.keys() if k not in accepted_keys] + # silently ignore `return_dict_in_generate` + if "return_dict_in_generate" in generation_kwargs: + generation_kwargs.pop("return_dict_in_generate") + # If any unsupported keys remain, warn and ignore them + if unsupported_keys: + import warnings + + warnings.warn( + f"HookedTransformer.generate received unsupported generation kwargs; ignoring: {unsupported_keys}", + UserWarning, + ) + # Clear unsupported keys + for k in unsupported_keys: + generation_kwargs.pop(k, None) + + # Optionally collect logits at each generation step for downstream tooling/tests + logits_seq_list: Optional[List[torch.Tensor]] = [] if output_logits_flag else None + shortformer_pos_embed = None embeds = input if input_type == "embeds" else self.embed(input) @@ -2033,6 +2066,10 @@ def generate( ) final_logits = logits[:, -1, :] + if output_logits_flag: + assert logits_seq_list is not None + logits_seq_list.append(final_logits.clone()) + if do_sample: if input_type in [ "str", @@ -2089,11 +2126,35 @@ def generate( self.tokenizer.decode(tokens, skip_special_tokens=True) for tokens in output_tokens ] - return decoded_texts[0] if len(decoded_texts) == 1 else decoded_texts + result: Any = decoded_texts[0] if len(decoded_texts) == 1 else decoded_texts elif return_type == "tokens": - return output_tokens + result = cast(Any, output_tokens) + else: + result = cast(Any, embeds) + + if output_logits_flag: + # Adhere to HF ModelOutput format with sequences (tokens) and logits (per-step) + from transformers.utils import ModelOutput # type: ignore + + def _logits_to_tuple(logits_list: list[torch.Tensor]) -> tuple[torch.Tensor, ...]: + assert logits_list is not None + # Convert list of [batch, vocab] tensors to tuple + return tuple(logits_list) + + try: + from transformers.generation.utils import GenerateDecoderOnlyOutput + + return GenerateDecoderOnlyOutput( + sequences=cast(torch.LongTensor, output_tokens), + # HF's type hint tuple[FloatTensor] is really tuple[FloatTensor, ...] + logits=_logits_to_tuple(logits_seq_list), # type: ignore[arg-type] + ) + except (ImportError, AttributeError): + # Fallback if GenerateDecoderOnlyOutput not available in this transformers version + # `sequences` expects a tensor of token ids + return ModelOutput(sequences=output_tokens, logits=_logits_to_tuple(logits_seq_list)) # type: ignore[arg-type] else: - return embeds + return result # Give access to all weights as properties. 
@property diff --git a/transformer_lens/components/abstract_attention.py b/transformer_lens/components/abstract_attention.py index 5f026f493..0d144a741 100644 --- a/transformer_lens/components/abstract_attention.py +++ b/transformer_lens/components/abstract_attention.py @@ -8,6 +8,7 @@ import torch.nn.functional as F from better_abc import abstract_attribute from jaxtyping import Float, Int +from torch import Tensor from transformers.utils.import_utils import is_bitsandbytes_available from transformer_lens.cache.key_value_cache_entry import ( @@ -280,8 +281,7 @@ def forward( raise TypeError(f"Expected 'pattern' to be a Tensor, got {type(pattern)}") pattern = torch.where(torch.isnan(pattern), torch.zeros_like(pattern), pattern) pattern = self.hook_pattern(pattern) # [batch, head_index, query_pos, key_pos] - pattern = pattern.to(self.cfg.dtype) - pattern = pattern.to(v.device) + pattern = pattern.to(device=v.device, dtype=v.dtype) z = self.calculate_z_scores(v, pattern) # [batch, pos, head_index, d_head] if not self.cfg.use_attn_result: if self.cfg.load_in_4bit: @@ -301,15 +301,21 @@ def forward( self.W_O, "head_index d_head d_model -> d_model (head_index d_head)" ) - if self.b_O.device != w.device: - w = w.to(self.b_O.device) - if self.b_O.device != z.device: - z = z.to(self.b_O.device) + # Move output projection weights and bias to the same device as z + # so that the final linear operation occurs on the device of the inputs + if w.device != z.device: + w = w.to(z.device) + b_O: Tensor = self.b_O + if b_O.device != z.device: + b_O = b_O.to(z.device) + # Ensure z has the same dtype as weights used in the output projection + if z.dtype != w.dtype: + z = z.to(w.dtype) out = F.linear( z.reshape(z.shape[0], z.shape[1], self.cfg.d_head * self.cfg.n_heads), w, - self.b_O, + b_O, ) else: # Explicitly calculate the attention result so it can be accessed by a hook @@ -329,6 +335,11 @@ def forward( self.W_O, "head_index d_head d_model -> 1 1 head_index d_head d_model", ) + if w.device != z.device: + w = w.to(z.device) + # Ensure z has the same dtype as w before multiplication + if z.dtype != w.dtype: + z = z.to(w.dtype) z = einops.rearrange( z, "batch pos head_index d_head -> batch pos head_index d_head 1" ) diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py index 141af5e2e..ae44f294d 100644 --- a/transformer_lens/model_bridge/bridge.py +++ b/transformer_lens/model_bridge/bridge.py @@ -39,6 +39,7 @@ ) from transformer_lens.model_bridge.get_params_util import get_bridge_params from transformer_lens.utilities.aliases import resolve_alias +from transformer_lens.utilities.devices import move_to_and_update_config if TYPE_CHECKING: from transformer_lens.ActivationCache import ActivationCache @@ -1621,7 +1622,10 @@ def generate( padding_side: Optional[str] = None, return_type: Optional[str] = "input", verbose: bool = True, - ) -> Union[str, List[str], torch.Tensor]: + output_logits: bool = False, + ) -> str | list[str] | torch.Tensor | Any: # Any for transformers.utils.ModelOutput + # Using Any due to beartype's forward reference resolution limitations. + # See: https://github.com/beartype/beartype/issues/546 """Sample tokens from the model. Sample tokens from the model until the model outputs eos_token or max_new_tokens is reached. 
@@ -1642,9 +1646,11 @@ def generate( padding_side: Not used in Bridge (kept for API compatibility) return_type: The type of output to return - 'input', 'str', or 'tokens' verbose: Not used in Bridge (kept for API compatibility) + output_logits: If True, return a ModelOutput with sequences and logits tuple Returns: - Generated sequence as string, list of strings, or tensor depending on input type and return_type + Generated sequence as string, list of strings, or tensor depending on input type and return_type. + If output_logits=True, returns a ModelOutput-like object with 'sequences' and 'logits' attributes. """ # Convert input to tokens if isinstance(input, str): @@ -1694,6 +1700,9 @@ def generate( # Track which sequences have finished finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device) + # Optionally collect logits at each generation step for downstream tooling/tests + logits_seq_list: list[torch.Tensor] | None = [] if output_logits else None + # Generate tokens current_tokens = input_tokens.clone() sampled_tokens_list = [] @@ -1704,6 +1713,10 @@ def generate( logits = self(current_tokens, return_type="logits") final_logits = logits[:, -1, :] + # Collect logits if requested + if logits_seq_list is not None: + logits_seq_list.append(final_logits.clone()) + # Sample next token if do_sample: sampled_tokens = utils.sample_logits( @@ -1740,6 +1753,33 @@ def generate( sampled_tokens = torch.cat(sampled_tokens_list, dim=1) output_tokens = torch.cat([input_tokens, sampled_tokens], dim=1) + # Return ModelOutput if output_logits was requested + if output_logits and logits_seq_list is not None: + from transformers.utils import ModelOutput # type: ignore + + def _logits_to_tuple(logits_list: list[torch.Tensor]) -> tuple[torch.Tensor, ...]: + assert logits_list is not None + # Convert list of [batch, vocab] tensors to tuple + return tuple(logits_list) + + try: + from transformers.generation.utils import GenerateDecoderOnlyOutput + + # Return a HF-compatible ModelOutput structure + # GenerateDecoderOnlyOutput expects: sequences, scores (optional), logits (optional) + return GenerateDecoderOnlyOutput( + sequences=cast(torch.LongTensor, output_tokens), + # HF's type hint says tuple[FloatTensor] but should be tuple[FloatTensor, ...] + # (variable-length tuple with one element per generated token) + logits=_logits_to_tuple(logits_seq_list), # type: ignore[arg-type] + ) + except (ImportError, AttributeError): + # Fallback if GenerateDecoderOnlyOutput not available in this transformers version + return ModelOutput( + sequences=output_tokens, + logits=_logits_to_tuple(logits_seq_list), + ) + # Format output if return_type == "str": if input_type == "str": @@ -1753,16 +1793,220 @@ def generate( else: # return_type == "tokens" return output_tokens + def hf_generate( + self, + input: str | list[str] | torch.Tensor = "", + max_new_tokens: int = 10, + stop_at_eos: bool = True, + eos_token_id: int | None = None, + do_sample: bool = True, + top_k: int | None = None, + top_p: float | None = None, + temperature: float = 1.0, + use_past_kv_cache: bool = True, + return_type: str | None = "input", + **generation_kwargs, + ) -> str | list[str] | torch.Tensor | Any: # Any for HF ModelOutput types + # Using Any due to beartype's forward reference resolution limitations. + # See: https://github.com/beartype/beartype/issues/546 + """Generate text using the underlying HuggingFace model with full HF API support. 
+
+        This method provides direct access to HuggingFace's generation API, forwarding all
+        generation parameters (including output_scores, output_logits, output_attentions,
+        output_hidden_states) directly to the underlying HF model. Use this when you need
+        full HuggingFace generation features not supported by the standard generate() method.
+
+        For standard generation compatible with HookedTransformer, use generate() instead.
+
+        Args:
+            input: Text string, list of strings, or tensor of tokens
+            max_new_tokens: Maximum number of tokens to generate
+            stop_at_eos: If True, stop generating tokens when the model outputs eos_token
+            eos_token_id: The token ID to use for end of sentence
+            do_sample: If True, sample from the model's output distribution
+            top_k: Number of tokens to sample from
+            top_p: Probability mass to sample from
+            temperature: Temperature for sampling
+            use_past_kv_cache: If True, use KV caching for faster generation
+            return_type: The type of output to return - 'input', 'str', or 'tokens'
+            **generation_kwargs: Additional HuggingFace generation parameters including:
+                - output_scores: Return generation scores
+                - output_logits: Return generation logits
+                - output_attentions: Return attention weights
+                - output_hidden_states: Return hidden states
+                - return_dict_in_generate: Return ModelOutput object
+                - And any other HF generation parameters
+
+        Returns:
+            Generated sequence as string, list of strings, tensor, or HF ModelOutput
+            depending on input type, return_type, and generation_kwargs.
+
+        Example::
+
+            # Get full HF ModelOutput with logits and attentions
+            from transformer_lens.model_bridge import TransformerBridge
+            model = TransformerBridge.boot_transformers("gpt2")
+            result = model.hf_generate(
+                "Hello world",
+                max_new_tokens=5,
+                output_logits=True,
+                output_attentions=True,
+                return_type="tokens",  # return the raw HF ModelOutput
+            )
+            print(result.sequences)  # Generated tokens
+            print(result.logits)  # Logits for each generation step
+            print(result.attentions)  # Attention weights
+        """
+        # Handle string input by tokenizing it
+        if isinstance(input, str):
+            inputs = self.tokenizer(input, return_tensors="pt", padding=False, truncation=False).to(
+                self.cfg.device
+            )
+            input_ids = inputs["input_ids"]
+            input_type = "str"
+        elif isinstance(input, list):
+            inputs = self.tokenizer(input, return_tensors="pt", padding=True, truncation=False).to(
+                self.cfg.device
+            )
+            input_ids = inputs["input_ids"]
+            input_type = "list"
+        else:
+            input_ids = input
+            if input_ids.device != self.cfg.device:
+                input_ids = input_ids.to(self.cfg.device)
+            input_type = "tokens"
+
+        # Build generation_kwargs from explicit args and kwargs
+        generation_kwargs = dict(generation_kwargs) if generation_kwargs is not None else {}
+        generation_kwargs.update(
+            {
+                "max_new_tokens": max_new_tokens,
+                "do_sample": do_sample,
+                "temperature": temperature,
+                "pad_token_id": self.tokenizer.eos_token_id,
+            }
+        )
+
+        if top_k is not None:
+            generation_kwargs["top_k"] = top_k
+        if top_p is not None:
+            generation_kwargs["top_p"] = top_p
+        if eos_token_id is not None:
+            generation_kwargs["eos_token_id"] = eos_token_id
+        elif stop_at_eos and self.tokenizer.eos_token_id is not None:
+            generation_kwargs["eos_token_id"] = self.tokenizer.eos_token_id
+
+        if use_past_kv_cache:
+            generation_kwargs["use_cache"] = True
+
+        # HF dict flags that trigger ModelOutput returns
+        hf_dict_flags = (
+            "output_scores",
+            "output_logits",
+            "output_attentions",
+            "output_hidden_states",
+        )
+
+        # If any HF-style output flags are provided, ensure
return_dict_in_generate is set + any_flag_set = False + for f in hf_dict_flags: + if generation_kwargs.get(f) is not None: + generation_kwargs[f] = bool(generation_kwargs[f]) + any_flag_set = True + + if any_flag_set: + generation_kwargs.setdefault("return_dict_in_generate", True) + + # Generate using the original HuggingFace model + with torch.no_grad(): + outputs = self.original_model.generate(input_ids, **generation_kwargs) # type: ignore[operator] + + # Check if output is a ModelOutput + try: + from transformers.utils import ModelOutput # type: ignore + + is_model_output = isinstance(outputs, ModelOutput) + except Exception: + is_model_output = False + + # Return based on return_type and input format + if return_type == "input" or return_type is None: + if input_type == "str": + # Decode the full output back to string + if is_model_output and hasattr(outputs, "sequences"): + return self.tokenizer.decode(outputs.sequences[0], skip_special_tokens=True) + return self.tokenizer.decode(outputs[0], skip_special_tokens=True) + elif input_type == "list": + # Decode each sequence in the batch + if is_model_output and hasattr(outputs, "sequences"): + return [ + self.tokenizer.decode(seq, skip_special_tokens=True) + for seq in outputs.sequences + ] + return [self.tokenizer.decode(seq, skip_special_tokens=True) for seq in outputs] + else: + # Return the full token sequence including input + return outputs + elif return_type == "tokens": + return outputs + else: + # For other return types, default to the decoded text + if input_type == "str": + if is_model_output and hasattr(outputs, "sequences"): + return self.tokenizer.decode(outputs.sequences[0], skip_special_tokens=True) + return self.tokenizer.decode(outputs[0], skip_special_tokens=True) + elif input_type == "list": + if is_model_output and hasattr(outputs, "sequences"): + return [ + self.tokenizer.decode(seq, skip_special_tokens=True) + for seq in outputs.sequences + ] + return [self.tokenizer.decode(seq, skip_special_tokens=True) for seq in outputs] + else: + return outputs + def to(self, *args, **kwargs) -> "TransformerBridge": - """Move model to device or change dtype. + """Move model to device and/or change dtype. Args: args: Positional arguments for nn.Module.to kwargs: Keyword arguments for nn.Module.to + print_details: Whether to print details about device/dtype changes (default: True) Returns: Self for chaining """ + # Extract print_details if provided + print_details = kwargs.pop("print_details", True) + + # Handle both device and dtype changes + # torch.nn.Module.to() supports: to(device), to(dtype), to(device, dtype), + # to(device=...), to(dtype=...), to(device=..., dtype=...) 
+ target_device, target_dtype = None, None + + if len(args) >= 1: + first_arg = args[0] + if isinstance(first_arg, (torch.device, str)): + target_device = first_arg + elif isinstance(first_arg, torch.dtype): + target_dtype = first_arg + if len(args) >= 2: + second_arg = args[1] + if isinstance(second_arg, torch.dtype): + target_dtype = second_arg + + # these override positional args + if "device" in kwargs: + target_device = kwargs["device"] + if "dtype" in kwargs: + target_dtype = kwargs["dtype"] + + if target_device is not None: + move_to_and_update_config(self, target_device, print_details) + if target_dtype is not None: + move_to_and_update_config(self, target_dtype, print_details) + + # Move the original model with all original args/kwargs (with print_details removed) self.original_model = self.original_model.to(*args, **kwargs) return self diff --git a/transformer_lens/model_bridge/generalized_components/block.py b/transformer_lens/model_bridge/generalized_components/block.py index b97ad95e6..77566f51b 100644 --- a/transformer_lens/model_bridge/generalized_components/block.py +++ b/transformer_lens/model_bridge/generalized_components/block.py @@ -122,8 +122,10 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: first = output[0] if isinstance(first, torch.Tensor): first = self.hook_out(first) + # Always return tuple to maintain consistency with HF's expected format + # e.g. GPT2Model does hidden_states = outputs[0], it expects outputs to be a tuple if len(output) == 1: - return first + return (first,) output = (first,) + output[1:] return output if isinstance(output, torch.Tensor):
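# Illustrative usage sketch (not part of the patch above). It exercises the two
# generation paths this diff adds: HookedTransformer.generate(output_logits=True)
# and TransformerBridge.hf_generate(). The model name ("gpt2"), device ("cpu"),
# and prompts are assumptions borrowed from the test fixtures in
# tests/integration/test_generation_compatibility.py; expected shapes follow the
# assertions in those tests.

from transformer_lens import HookedTransformer
from transformer_lens.model_bridge import TransformerBridge

model = HookedTransformer.from_pretrained("gpt2", device="cpu")

# With output_logits=True, generate() returns a ModelOutput-style object holding
# `sequences` (tokens) and `logits` (one [batch, d_vocab] tensor per generated token).
out = model.generate(
    "The quick brown",
    max_new_tokens=3,
    do_sample=False,
    verbose=False,
    output_logits=True,
)
print(out.sequences.shape)  # [batch, prompt_len + new_tokens]
print(len(out.logits), out.logits[0].shape)  # 3 steps, each [batch, d_vocab]

# hf_generate() forwards HF-style flags to the wrapped model; with return_type="tokens"
# and an output flag set, the raw HF GenerateDecoderOnlyOutput is returned.
bridge = TransformerBridge.boot_transformers("gpt2", device="cpu")
result = bridge.hf_generate(
    "Hello world",
    max_new_tokens=3,
    do_sample=False,
    output_logits=True,
    return_type="tokens",
)
print(result.sequences)  # Full token sequences including the prompt
print(len(result.logits))  # One logits tensor per generated token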