LLM MCTS Inference


An experimental project using Monte Carlo Tree Search (MCTS) to refine Large Language Model (LLM) responses for better accuracy and decision-making.

Overview

This project uses MCTS to explore multiple answer candidates generated by an LLM. By iteratively generating an initial answer, evaluating it, and refining it based on targeted self-feedback, the system aims to improve response quality and decision-making. The approach trades additional test-time compute for more precise and robust model outputs.

MCTS Inference Process

The process follows these key steps:

  • Initial Answer Generation: Uses greedy decoding to generate an initial response.
  • Feedback Generation: Provides constructive, concise feedback on initial answers. The feedback is generated by the model itself.
  • Iterative Refinement: Refines responses based on the feedback through additional model queries.
  • Monte Carlo Tree Search: Employs MCTS to explore and evaluate multiple answer paths.
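
The four steps above can be sketched as a small MCTS loop. This is a minimal illustration, not the project's actual implementation: the functions `generate_answer`, `generate_feedback`, `refine_answer`, and `score_answer` are hypothetical stand-ins for the real LLM calls, and here they are stubbed out so the sketch runs on its own.

```python
import math
import random

random.seed(0)

# Hypothetical stand-ins for the LLM calls (assumptions, not the project's API).
def generate_answer(prompt: str) -> str:
    return f"answer to: {prompt}"          # step 1: initial (greedy) answer

def generate_feedback(answer: str) -> str:
    return f"feedback on: {answer}"        # step 2: self-generated feedback

def refine_answer(answer: str, feedback: str) -> str:
    return answer + " (refined)"           # step 3: refinement via feedback

def score_answer(answer: str) -> float:
    return random.random()                 # a real system would self-evaluate

class Node:
    def __init__(self, answer, parent=None):
        self.answer = answer
        self.parent = parent
        self.children = []
        self.visits = 0
        self.value = 0.0

def ucb1(node, c=1.41):
    # Standard UCB1: balance exploitation (mean value) and exploration.
    if node.visits == 0:
        return float("inf")
    return node.value / node.visits + c * math.sqrt(
        math.log(node.parent.visits) / node.visits
    )

def mcts(prompt, iterations=5, max_children=3):
    root = Node(generate_answer(prompt))
    for _ in range(iterations):
        # Selection: descend via UCB1 until a node can still be expanded.
        node = root
        while len(node.children) >= max_children:
            node = max(node.children, key=ucb1)
        # Expansion: refine the current answer using self-feedback.
        feedback = generate_feedback(node.answer)
        child = Node(refine_answer(node.answer, feedback), parent=node)
        node.children.append(child)
        # Simulation + backpropagation: score and propagate the reward upward.
        reward = score_answer(child.answer)
        while child is not None:
            child.visits += 1
            child.value += reward
            child = child.parent
    # Return the answer with the best mean reward among explored nodes.
    nodes, stack = [], [root]
    while stack:
        n = stack.pop()
        nodes.append(n)
        stack.extend(n.children)
    best = max(nodes, key=lambda n: n.value / max(n.visits, 1))
    return best.answer

print(mcts("What is 2 + 2?"))
```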

Experimental Results

The performance of this approach was evaluated on a subset of the GSM8k test split using the Llama3.2-1B-instruct model with vLLM. A baseline run using zero-shot prompting achieved a pass@8 score of 74% and a majority@8 score of 27%. When applying MCTS for iterative refinement, the pass@8 score marginally increased to 75%, while the majority@8 score improved significantly to 39%. The evaluation was done with llm-eval.

These results suggest that while MCTS does not drastically improve the probability of generating at least one correct answer (pass@8), it significantly enhances response consistency (majority@8), making the model more reliable in decision-making scenarios.
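
The two metrics can be computed directly from k sampled answers per question. A minimal sketch (assuming exact string match against a known gold answer, which is a simplification of how GSM8k answers are actually extracted and compared):

```python
from collections import Counter

def pass_at_k(samples: list[str], gold: str) -> bool:
    # pass@k: at least one of the k samples is correct.
    return any(s == gold for s in samples)

def majority_at_k(samples: list[str], gold: str) -> bool:
    # majority@k: the most frequent sample is correct (majority vote).
    most_common, _ = Counter(samples).most_common(1)[0]
    return most_common == gold

# Toy example: 8 samples for one question with gold answer "42".
samples = ["41", "42", "42", "17", "42", "40", "42", "39"]
print(pass_at_k(samples, "42"))      # True: "42" appears at least once
print(majority_at_k(samples, "42"))  # True: "42" is the most frequent sample
```

This also shows why the two scores can diverge: pass@8 only needs one correct sample, while majority@8 requires the correct answer to dominate the vote, which is where the reported consistency gain appears.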

Why Llama3.2-1B-instruct?

A smaller model was selected for this experiment to better illustrate the impact of MCTS. Larger models already achieve high accuracy on GSM8k, making it difficult to demonstrate meaningful improvements. The 1B-parameter model provides a more realistic proof of concept by:

  • Being resource-efficient, allowing for scalable experimentation.
  • Providing a challenging test case, as smaller models struggle more with GSM8k, making improvements more noticeable.
  • Keeping the evaluation relevant, since GSM8k has been extensively benchmarked by larger models, leaving little room for additional gains.

Installation

Dependencies

  • Python: Version 3.11 or higher

The project depends mainly on the following packages:

  • instructor for guided generation
  • litellm provides a unified API to interact with multiple LLM providers

Setup Instructions

To install the package directly from PyPI, run:

    pip install llm-mcts-inference

To install from source, follow these steps:

  1. Clone the Repository:

    git clone https://github.com/brotSchimmelt/llm-mcts-inference.git
    cd llm-mcts-inference
  2. Install the Project Dependencies:

    If you use uv, run the following commands to create a virtualenv and install all requirements:

    uv venv --python 3.11
    uv sync

    Otherwise, install the required packages with pip:

    pip install .
  3. Configure Environment Variables: Rename the provided example.env file to .env and update it with your API keys or other configuration details as needed.
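
For step 3, this amounts to copying the template and filling in whichever provider keys you use. The variable name below is an example of the kind of key litellm reads from the environment, not a list of what this project requires:

```shell
# Copy the template and edit it with your own keys.
cp example.env .env
# Then set the keys for your provider, e.g.:
# OPENAI_API_KEY=<your key>
```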

Usage

Use the MonteCarloLLM class to generate and improve responses via MCTS:

from llm_mcts_inference.MonteCarloLLM import MonteCarloLLM

# Initialize with a specific model; defaults are defined in settings
llm = MonteCarloLLM(model_name="openai/gpt-4o-mini")

# Define your prompt
prompt = "What is the capital of France?"

# Generate a response using Monte Carlo Tree Search
result = llm.generate(prompt=prompt, iterations=5, max_children=3)

# Output the final improved answer
print("Final Answer:", result.answer)

# Optionally, display the sequence of nodes (answers) along the best path
print("Best Path:", [node.answer for node in result.valid_path])

License

This project is licensed under the MIT license.
