32 changes: 32 additions & 0 deletions tests/acceptance/test_hooked_transformer.py
@@ -236,6 +236,38 @@ def test_bloom_similarity_with_hf_model_with_kv_cache_activated():
assert output_tf == output_hf_str


def test_bloom_similarity_with_hf_model_with_kv_cache_activated_stream():
tf_model = HookedTransformer.from_pretrained(
"bigscience/bloom-560m", default_prepend_bos=False, device="cpu"
)

hf_model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
hf_tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")

final_output = ""
for result in tf_model.generate_stream(
text,
do_sample=False,
use_past_kv_cache=True,
verbose=False,
max_new_tokens=10,
max_tokens_per_yield=10,
):
final_output += tf_model.to_string(result[0])

hf_input_ids = hf_tokenizer(text, return_tensors="pt").input_ids
output_hf_tokens = hf_model.generate(
hf_input_ids,
do_sample=False,
max_new_tokens=10,
)
output_hf_str = hf_tokenizer.decode(output_hf_tokens[0], skip_special_tokens=True)

assert (
final_output == output_hf_str
), f"\nStreaming output: {final_output}\nHF output: {output_hf_str}"


def check_norm_folding(
model_name,
hf_model=None,
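As a companion to the HF-equivalence test above, here is a minimal sketch of consuming generate_stream with batched token input; the model name, prompts, and stop-token list are illustrative placeholders, not part of this change:

import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2", device="cpu")

# generate_stream accepts a single string or a [batch, pos] token tensor;
# lists of strings should be tokenized (and padded) into a batch first.
tokens = model.to_tokens(["Hello world", "Lorem ipsum"])

chunks = []
for chunk in model.generate_stream(
    tokens,
    max_new_tokens=20,
    max_tokens_per_yield=5,
    eos_token_id=[model.tokenizer.eos_token_id],  # list form: stop on any listed ID
    verbose=False,
):
    # Each chunk is an Int[torch.Tensor, "batch pos"]; the first chunk includes
    # the prompt tokens, later chunks hold only newly generated tokens.
    chunks.append(chunk)

full_tokens = torch.cat(chunks, dim=-1)  # reassembles prompt + continuation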
4 changes: 2 additions & 2 deletions transformer_lens/HookedEncoderDecoder.py
@@ -484,13 +484,13 @@ def generate(
else:
return decoder_input

@overload
@overload # type: ignore[overload-overlap]
def run_with_cache(
self, *model_args: Any, return_cache_object: Literal[True] = True, **kwargs: Any
) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], ActivationCache]:
...

@overload
@overload # type: ignore[overload-overlap]
def run_with_cache(
self, *model_args: Any, return_cache_object: Literal[False] = False, **kwargs: Any
) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Dict[str, torch.Tensor]]:
227 changes: 227 additions & 0 deletions transformer_lens/HookedTransformer.py
@@ -13,6 +13,7 @@

import logging
import os
from collections.abc import Generator
from typing import (
Any,
Dict,
@@ -2203,6 +2204,232 @@ def _logits_to_tuple(logits_list: list[torch.Tensor]) -> tuple[torch.Tensor, ...]:
else:
return result

@torch.inference_mode()
def generate_stream(
self,
input: Union[str, Float[torch.Tensor, "batch pos"]] = "",
max_new_tokens: int = 10,
max_tokens_per_yield: int = 25,
stop_at_eos: bool = True,
eos_token_id: Optional[int] = None,
do_sample: bool = True,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
temperature: float = 1.0,
freq_penalty: float = 0.0,
use_past_kv_cache: bool = True,
prepend_bos: Optional[bool] = USE_DEFAULT_VALUE,
padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE,
return_type: Optional[str] = "input",
verbose: bool = True,
) -> Generator[Union[Int[torch.Tensor, "batch pos"], str], None, None]:
"""Stream tokens from the Model as they are generated.

Sample tokens from the model until the model outputs eos_token or max_new_tokens is reached,
yielding batches of tokens progressively during generation rather than waiting for the entire
sequence to be generated.

To avoid fiddling with ragged tensors, if we input a batch of tokens and some sequences
finish (by producing an EOS token), we keep running the model on the entire batch, but
throw away the output for a finished sequence and just keep appending EOS tokens to pad.

This supports entering a single string, but not a list of strings - if the strings don't
tokenize to exactly the same length, this gets messy. If that functionality is needed,
convert them to a batch of tokens and input that instead.

Args:
input (Union[str, Int[torch.Tensor, "batch pos"]]): Either a batch of tokens ([batch,
pos]) or a text string (this will be converted to a batch of tokens with batch size
1).
max_new_tokens (int): Maximum number of tokens to generate.
max_tokens_per_yield (int): Maximum number of tokens to accumulate before yielding.
Controls how frequently the function yields tokens during generation.
stop_at_eos (bool): If True, stop generating tokens when the model outputs eos_token.
eos_token_id (Optional[Union[int, Sequence]]): The token ID to use for end
of sequence. If None, use the tokenizer's eos_token_id - required if using
stop_at_eos. It's also possible to provide a list of token IDs (not just the
eos_token_id), in which case generation will stop when any of them is output
(useful e.g. for stable_lm).
do_sample (bool): If True, sample from the model's output distribution. Otherwise, use
greedy search (take the max logit each time).
top_k (int): Number of tokens to sample from. If None, sample from all tokens.
top_p (float): Probability mass to sample from. If 1.0, sample from all tokens. If <1.0,
we take the top tokens with cumulative probability >= top_p.
temperature (float): Temperature for sampling. Higher values will make the model more
random (limit of temp -> 0 is just taking the top token, limit of temp -> inf is
sampling from a uniform distribution).
freq_penalty (float): Frequency penalty for sampling - how much to penalise previous
tokens. Higher values will make the model more random.
use_past_kv_cache (bool): If True, create and use cache to speed up generation.
prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend
the BOS token to the input (applicable when input is a string). Defaults to None,
implying usage of self.cfg.default_prepend_bos (default is True unless specified
otherwise). Pass True or False to override the default.
padding_side (Union[Literal["left", "right"], None], optional): Overrides
self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple
strings of different lengths.
return_type (Optional[str]): The type of the output to return - either a string (str),
a tensor of tokens (tensor) or whatever the format of the input was (input).
verbose (bool): If True, show tqdm progress bars for generation.

Yields:
outputs (Union[Int[torch.Tensor, "batch pos"], str]): Batches of generated tokens,
yielded progressively during generation. Each yield contains the tokens accumulated
since the last yield, up to max_tokens_per_yield; the first yield also includes the
input tokens.
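
Example:
    A minimal illustrative sketch (the model name and prompt are placeholders;
    note that the first chunk includes the prompt tokens):

    >>> model = HookedTransformer.from_pretrained("gpt2")
    >>> streamed = ""
    >>> for chunk in model.generate_stream(
    ...     "The quick brown fox", max_new_tokens=16, max_tokens_per_yield=4
    ... ):
    ...     streamed += model.to_string(chunk[0])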
"""

with utils.LocallyOverridenDefaults(
self, prepend_bos=prepend_bos, padding_side=padding_side
):
if type(input) == str:
# If text, convert to tokens (batch_size=1)
assert (
self.tokenizer is not None
), "Must provide a tokenizer if passing a string to the model"
tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)
else:
assert isinstance(input, torch.Tensor), "Input must be a tensor when not a string"
tokens = input

if return_type == "input":
if type(input) == str:
return_type = "str"
else:
return_type = "tensor"

assert isinstance(tokens, torch.Tensor)
batch_size, ctx_length = tokens.shape
device = get_device_for_block_index(0, self.cfg)
tokens = tokens.to(device)
if use_past_kv_cache:
past_kv_cache = TransformerLensKeyValueCache.init_cache(
self.cfg, self.cfg.device, batch_size
)
else:
past_kv_cache = None

stop_tokens: List[int] = []
eos_token_for_padding = 0
assert self.tokenizer is not None
if stop_at_eos:
tokenizer_has_eos_token = (
self.tokenizer is not None and self.tokenizer.eos_token_id is not None
)
if eos_token_id is None:
assert (
tokenizer_has_eos_token
), "Must pass a eos_token_id if stop_at_eos is True and tokenizer is None or has no eos_token_id"

eos_token_id = self.tokenizer.eos_token_id

if isinstance(eos_token_id, int):
stop_tokens = [eos_token_id]
eos_token_for_padding = eos_token_id
else:
# eos_token_id is a Sequence (e.g. list or tuple)
stop_tokens = eos_token_id
eos_token_for_padding = (
self.tokenizer.eos_token_id if tokenizer_has_eos_token else eos_token_id[0]
)

# An array to track which sequences in the batch have finished.
finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device)

accumulated_tokens: Optional[torch.Tensor] = None
tokens_since_last_yield = 0

# Currently nothing in HookedTransformer changes with eval, but this is here in case
# that changes in the future.
self.eval()
for index in tqdm.tqdm(range(max_new_tokens), disable=not verbose):
# While generating, we keep computing logits, throw away all but the logits for the
# final position, and then use those to sample the next token from the distribution.
# We keep appending the sampled tokens to the end of tokens.
if use_past_kv_cache:
# We just take the final tokens, as a [batch, 1] tensor
if index > 0:
logits = self.forward(
tokens[:, -1:],
return_type="logits",
prepend_bos=prepend_bos,
padding_side=padding_side,
past_kv_cache=past_kv_cache,
)
else:
logits = self.forward(
tokens,
return_type="logits",
prepend_bos=prepend_bos,
padding_side=padding_side,
past_kv_cache=past_kv_cache,
)
else:
# We input the entire sequence, as a [batch, pos] tensor, since we aren't using
# the cache.
logits = self.forward(
tokens,
return_type="logits",
prepend_bos=prepend_bos,
padding_side=padding_side,
)
final_logits = logits[:, -1, :]

if do_sample:
sampled_tokens = utils.sample_logits(
final_logits,
top_k=top_k,
top_p=top_p,
temperature=temperature,
freq_penalty=freq_penalty,
tokens=tokens,
).to(get_device_for_block_index(0, self.cfg))
else:
sampled_tokens = final_logits.argmax(-1).to(
get_device_for_block_index(0, self.cfg)
)

if stop_at_eos:
# For all unfinished sequences, add on the next token. If a sequence was
# finished, throw away the generated token and add eos_token_for_padding
# instead.
sampled_tokens[finished_sequences] = eos_token_for_padding
finished_sequences.logical_or_(
torch.isin(
sampled_tokens.to(self.cfg.device),
torch.tensor(stop_tokens).to(self.cfg.device),
)
)

new_tokens = sampled_tokens.unsqueeze(-1)

# Accumulate tokens until we hit max_tokens_per_yield
if index == 0:
accumulated_tokens = torch.cat([tokens, new_tokens], dim=-1)
tokens_since_last_yield = accumulated_tokens.shape[1]
else:
if accumulated_tokens is None:
accumulated_tokens = new_tokens
else:
accumulated_tokens = torch.cat([accumulated_tokens, new_tokens], dim=-1)
tokens_since_last_yield += 1

if tokens_since_last_yield >= max_tokens_per_yield:
yield accumulated_tokens
tokens_since_last_yield = 0
accumulated_tokens = None

tokens = torch.cat([tokens, new_tokens], dim=-1)

if stop_at_eos and finished_sequences.all():
# Yield any remaining accumulated tokens before breaking
if accumulated_tokens is not None:
yield accumulated_tokens
break

# Only yield remaining tokens if we didn't already yield them in the break case
if accumulated_tokens is not None and not (stop_at_eos and finished_sequences.all()):
yield accumulated_tokens

# Give access to all weights as properties.
@property
def W_U(self) -> Float[torch.Tensor, "d_model d_vocab"]:
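generate_stream forwards top_k, top_p, temperature, and freq_penalty to utils.sample_logits. As a rough standalone sketch of the filtering those arguments describe (an independent illustration of the technique, not the library's actual implementation), the sampling step looks roughly like:

import torch
from typing import Optional

def sample_next_token(
    final_logits: torch.Tensor,   # [batch, d_vocab] logits at the last position
    tokens: torch.Tensor,         # [batch, pos] context generated so far
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    temperature: float = 1.0,
    freq_penalty: float = 0.0,
) -> torch.Tensor:                # [batch] sampled token IDs
    if freq_penalty > 0:
        # Penalise each vocabulary entry by how often it already occurs.
        counts = torch.zeros_like(final_logits)
        counts.scatter_add_(1, tokens, torch.ones_like(tokens, dtype=final_logits.dtype))
        final_logits = final_logits - freq_penalty * counts
    # temp -> 0 approaches argmax; temp -> inf approaches uniform sampling.
    final_logits = final_logits / temperature
    if top_k is not None:
        # Drop everything below the k-th largest logit.
        kth_best = final_logits.topk(top_k, dim=-1).values[:, -1:]
        final_logits = final_logits.masked_fill(final_logits < kth_best, float("-inf"))
    if top_p is not None:
        # Nucleus sampling: keep the smallest prefix of the sorted vocabulary
        # whose cumulative probability reaches top_p (the top token is always kept).
        sorted_logits, sorted_idx = final_logits.sort(descending=True, dim=-1)
        cum_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
        drop_sorted = cum_probs > top_p
        drop_sorted[:, 1:] = drop_sorted[:, :-1].clone()
        drop_sorted[:, 0] = False
        drop = torch.zeros_like(drop_sorted).scatter(1, sorted_idx, drop_sorted)
        final_logits = final_logits.masked_fill(drop, float("-inf"))
    return torch.distributions.Categorical(logits=final_logits).sample()

Greedy decoding (do_sample=False) skips all of this and takes final_logits.argmax(-1), which is what the Bloom streaming test above relies on for deterministic comparison against Hugging Face.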