Conversation

@kiranandcode
Contributor

This PR builds on top of #412. Often we would like to use an LLM to generate types that cannot be hashed, or whose outputs do not lie in a small enough set for voting to complete in any reasonable amount of time (think of generating anything with a string component, or callables). Yet these outputs may still have some "essence" in a smaller set that would still provide a useful signal for voting.

This PR implements a ReducedKAheadSampler, which allows the user to provide a reducer function that maps LLM-generated outputs to a reduced key set for voting, and then runs the usual k-ahead sampling algorithm.

# Standard-library and OpenAI imports; the remaining names (Template, Field,
# handler, LLMLoggingHandler, OpenAIAPIProvider, ReducedKAheadSampler) come
# from this library and their imports are omitted here.
from dataclasses import dataclass
from enum import Enum

import openai


class MovieGenre(str, Enum):
    ACTION = "action"
    COMEDY = "comedy"
    DRAMA = "drama"
    HORROR = "horror"
    SCIFI = "sci-fi"
    ROMANCE = "romance"

@dataclass(frozen=True)
class MovieClassification:
    genre: MovieGenre
    explanation: str = Field(..., description="explanation for the given movie classification")

@Template.define
def classify_genre(plot: str) -> MovieClassification:
    """Classify the movie genre based on this plot: {plot}"""
    raise NotImplementedError


plot = "A rogue cop must stop a terrorist group from detonating bombs across the city."

with handler(LLMLoggingHandler()), handler(OpenAIAPIProvider(openai.OpenAI())):

    def reducer(classification: MovieClassification) -> MovieGenre:
        return classification.genre
    def select_best(results: list[MovieClassification]) -> MovieClassification:
        return max(results, key=lambda res: len(res.explanation))
    
    sampler = ReducedKAheadSampler(reducer, select_best=select_best)
    classification = handler(sampler)(classify_genre)(plot)
    print(classification)
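
For intuition, the reduce-then-vote step can be pictured roughly as follows. This is only an illustrative sketch, not the ReducedKAheadSampler implementation: it does a plain majority vote over the reduced keys, omits the k-ahead early-stopping logic, and reduce_and_vote is a hypothetical helper name.

from collections import Counter
from typing import Callable, Hashable, TypeVar

T = TypeVar("T")
K = TypeVar("K", bound=Hashable)

def reduce_and_vote(
    samples: list[T],
    reducer: Callable[[T], K],
    select_best: Callable[[list[T]], T],
) -> T:
    # Map each raw sample to its reduced key and vote over the keys.
    keys = [reducer(s) for s in samples]
    winning_key, _ = Counter(keys).most_common(1)[0]
    # Hand every sample that reduced to the winning key to select_best.
    winners = [s for s, k in zip(samples, keys) if k == winning_key]
    return select_best(winners)

In the movie example above, reducer collapses each MovieClassification to its genre, so voting happens over MovieGenre values, while select_best picks a full MovieClassification from among the samples that agreed with the winning genre.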

@kiranandcode
Contributor Author

Orthogonal design consideration: I don't feel like KAheadSampler as an interpretation of Template.__call__ really makes sense. In particular, it overrides the semantics of Template.__call__ for every Template call, when the algorithm only applies to one particular Template.__call__.

@jfeser
Contributor

jfeser commented Nov 25, 2025

I think the fact that we need separate Template.__call__ handlers for this and the previous voting algorithm, and that handling Template.__call__ is too coarse, suggests that we should build simpler batching functionality and implement voting on top.
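
To make that concrete, here is one hypothetical shape the layering could take; none of these names exist in the library, and this is a sketch of the idea rather than a proposal for the actual API:

from typing import Callable, TypeVar

T = TypeVar("T")

def batched(template_call: Callable[..., T], n: int) -> Callable[..., list[T]]:
    # A generic batching combinator: run the template n times and collect the results.
    def run(*args, **kwargs) -> list[T]:
        return [template_call(*args, **kwargs) for _ in range(n)]
    return run

# Voting (plain or reduced) then composes as an ordinary function over the batch,
# instead of each voting scheme needing its own Template.__call__ handler, e.g.:
#   samples = batched(classify_genre, n=5)(plot)
#   classification = reduce_and_vote(samples, reducer, select_best)

Batching here just means collecting n calls; a real implementation could issue them concurrently, but the point is that the voting policy stays an ordinary function over a list of samples.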

