Merged
4 changes: 2 additions & 2 deletions .github/workflows/pr-checks.yaml
@@ -21,7 +21,7 @@ permissions:

jobs:
  run_code_quality:
-    runs-on: ubuntu-20.04
+    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3

@@ -50,7 +50,7 @@ jobs:
        shell: bash -el {0}

  run_pytest:
-    runs-on: ubuntu-20.04
+    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
        with:
10 changes: 9 additions & 1 deletion README.md
@@ -2,7 +2,7 @@

RadFact is a framework for the evaluation of model-generated radiology reports given a ground-truth report, **with or without grounding**. Leveraging the logical inference capabilities of large language models, RadFact is not a single number but a _suite_ of metrics, capturing aspects of precision and recall at text-only and text-and-grounding levels.

-RadFact was introduced in [MAIRA-2: Grounded Radiology Report Generation](https://aka.ms/maira-2). Here we provide an open-source implementation of the metric to facilitate its use and development.
+RadFact was introduced in [MAIRA-2: Grounded Radiology Report Generation](https://aka.ms/maira-2). Here we provide an open-source implementation of the metric to facilitate its use and development. The RadFact metric currently supports both `cxr` and `ct` report types.

## Table of Contents

@@ -157,6 +157,8 @@ options:
                        Path to the directory where the results will be saved as a json file.
  --bootstrap_samples BOOTSTRAP_SAMPLES
                        Number of bootstrap samples to use for computing the confidence intervals. Set to 0 to disable bootstrapping.
+  --report_type {cxr,ct}
+                        Type of report: 'cxr' for chest x-ray reports or 'ct' for CT reports.
```

- for non-grounded reports (findings generation narrative text):
@@ -177,6 +179,12 @@ The script computes confidence intervals for the metrics using bootstrapping. Th

⚠️**WARNING**: Some queries may fail due to endpoint limitations (timeouts, rate limits, etc.). When the LLM performing entailment verification fails, we **set these examples as not-entailed by default**. If this occurs in a significant number of cases, the results will not be reliable. The final metrics dict contains the number of such skipped queries under the key `num_llm_failures`. The script will print the number of skipped queries at the end of the run and store them in the `skipped` directory under the run id folder. You will also see a warning message in the logs for each failed query: `WARNING: No response for example {query_id}. Setting as NOT ENTAILED`.
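As a quick post-run sanity check, the failure count can be read back from the saved results. A minimal sketch (the results filename is hypothetical; the `num_llm_failures` key is the one documented above):

```python
import json
from pathlib import Path

# Hypothetical filename: point this at the json file written to --output_dir.
results = json.loads(Path("outputs/radfact_results.json").read_text())

# Failed entailment-verification queries are scored as NOT ENTAILED by default,
# so a large failure count means the metrics should not be trusted.
num_failures = results.get("num_llm_failures", 0)
if num_failures:
    print(f"{num_failures} queries failed and were set as not entailed; interpret results with care.")
```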

+### Supporting Multiple Report Types
+RadFact supports different report types through the `report_type` field in the `RadFactMetric` class; see the usage sketch after this list. Currently supported options are:

+- `cxr` - Chest X-ray reports (default)
+- `ct` - CT scan reports
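A minimal sketch of selecting the report type programmatically, mirroring how the CLI wires it up (`ReportType` and `compute_metric_score` appear in the diffs below; the `RadFactMetric` import path and the toy inputs are assumptions, and config names are left at their defaults):

```python
from radfact.llm_utils.prompt_tasks import ReportType

# Assumed import path for RadFactMetric; adjust to where it lives in the package.
from radfact.metric.radfact import RadFactMetric

# ReportType("ct") mirrors how the CLI converts the --report_type string.
radfact_metric = RadFactMetric(
    is_narrative_text=True,
    report_type=ReportType("ct"),
)

# Toy narrative inputs keyed by study id (input shapes assumed for illustration).
candidates = {"study_1": "Small left pleural effusion."}
references = {"study_1": "There is a small pleural effusion on the left."}
score, results = radfact_metric.compute_metric_score(candidates, references)
print(score, results)
```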

### Split reports into phrases

We also provide a script to convert reports to phrases. This is useful when you have a narrative report and want to convert it to a list of phrases for RadFact evaluation. You can run this step offline and then use the output file as input to RadFact. Make sure you've set up the endpoints as described above before running the script. The `run_report_to_phrases` command runs `python src/radfact/cli/run_report_to_phrases.py` under the hood.
1 change: 0 additions & 1 deletion dev_environment.yaml
@@ -311,7 +311,6 @@ dependencies:
  - pyparsing==3.2.0
  - pysocks==1.7.1
  - pytest==8.3.3
-  - pytest-lazy-fixture==0.6.3
  - python-dateutil==2.9.0.post0
  - pytz==2024.2
  - pyyaml==6.0.2
1 change: 0 additions & 1 deletion pyproject.toml
@@ -78,7 +78,6 @@ test = [
"mock",
"pandas-stubs",
"pytest",
"pytest-lazy-fixture",
]

[project.urls]
12 changes: 12 additions & 0 deletions src/radfact/cli/run_radfact.py
@@ -10,6 +10,7 @@

import pandas as pd

+from radfact.llm_utils.prompt_tasks import ReportType
from radfact.data_utils.grounded_phrase_list import GroundedPhraseList
from radfact.llm_utils.report_to_phrases.processor import StudyIdType
from radfact.metric.bootstrapping import MetricBootstrapper
@@ -66,12 +67,14 @@ def compute_radfact_scores(
    candidates: InputDict,
    references: InputDict,
    is_narrative_text: bool,
+    report_type: ReportType,
    bootstrap_samples: int,
) -> dict[str, float]:
    radfact_metric = RadFactMetric(
        nli_config_name=radfact_config_name,
        phrase_config_name=phrases_config_name,
        is_narrative_text=is_narrative_text,
+        report_type=report_type,
    )
    if bootstrap_samples == 0:
        _, results = radfact_metric.compute_metric_score(candidates, references)
@@ -131,6 +134,13 @@ def main() -> None:
"bootstrapping.",
default=500,
)
parser.add_argument(
"--report_type",
type=str,
choices=["cxr", "ct"],
help="Type of report: 'cxr' for chest x-ray reports or 'ct' for CT reports.",
default="cxr",
)

args = parser.parse_args()
input_path = Path(args.input_path)
@@ -139,6 +149,7 @@ def main() -> None:
    radfact_config_name = args.radfact_config_name
    phrases_config_name = args.phrases_config_name
    bootstrap_samples = args.bootstrap_samples
+    report_type = ReportType(args.report_type)

    assert input_path.suffix in [".csv", ".json"], "Input file must be a csv or json file."
    assert input_path.suffix == ".csv" or not is_narrative_text, (
@@ -163,6 +174,7 @@ def main() -> None:
        references=references,
        is_narrative_text=is_narrative_text,
        bootstrap_samples=bootstrap_samples,
+        report_type=report_type,
    )

    print_fn = print_results if bootstrap_samples == 0 else print_bootstrap_results
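Taken together, the new flag threads straight through to the metric. A sketch of driving the updated helper from Python (assuming `compute_radfact_scores` is importable from the CLI module; the first two parameter names are inferred from the body shown above, and the config names and inputs are placeholders):

```python
from radfact.cli.run_radfact import compute_radfact_scores
from radfact.llm_utils.prompt_tasks import ReportType

results = compute_radfact_scores(
    radfact_config_name="radfact",             # placeholder config name
    phrases_config_name="report_to_phrases",   # placeholder config name
    candidates={"study_1": "No focal consolidation."},  # toy inputs
    references={"study_1": "Lungs are clear."},
    is_narrative_text=True,
    report_type=ReportType("cxr"),  # equivalent to --report_type cxr
    bootstrap_samples=0,            # disable bootstrapping, as documented above
)
```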
35 changes: 21 additions & 14 deletions src/radfact/llm_utils/nli/processor.py
@@ -10,6 +10,7 @@
from typing import Any, Callable

import pandas as pd
+from radfact.llm_utils.prompt_tasks import NLITaskOptions, ReportType
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import BaseMessage
from omegaconf import DictConfig
@@ -35,11 +36,9 @@
    StructuredProcessor,
    simple_formatter,
)
-from radfact.paths import OUTPUT_DIR, get_prompts_dir
+from radfact.paths import OUTPUT_DIR

logger = logging.getLogger(__name__)
-PARSING_TASK = "nli"
-PROMPTS_DIR = get_prompts_dir(task=PARSING_TASK)
RADFACT_SUBFOLDER = "radfact"


@@ -49,22 +48,22 @@ class MetricDataframeKeys(str, Enum):
STUDY_ID = "study_id"


def get_ev_processor_singlephrase(log_dir: Path) -> StructuredProcessor[ComparisonQuerySinglePhrase, EvidencedPhrase]:
def get_ev_processor_singlephrase(
report_type: ReportType, log_dir: Path
) -> StructuredProcessor[ComparisonQuerySinglePhrase, EvidencedPhrase]:
"""
Helper function to load the NLI processor with the correct system prompt and few-shot examples.

The setting here is to classify a SINGLE PHRASE at a time given the reference report.
Further, we do entailment verification, aka the binary version of NLI.

:param api_arguments: API arguments for the LLM.
:param report_type: The type of report, e.g., "ReportType.CXR" or "ReportType.CT".
:param log_dir: Directory to save logs.
:return: Processor for entailment verification.
"""

system_prompt_path = PROMPTS_DIR / "system_message_ev_singlephrase.txt"
few_shot_examples_path = PROMPTS_DIR / "few_shot_examples.json"
system_prompt = system_prompt_path.read_text()
few_shot_examples = load_examples_from_json(json_path=few_shot_examples_path, binary=True)
task = NLITaskOptions[report_type.name].value
system_prompt = task.system_message_path.read_text()
few_shot_examples = load_examples_from_json(json_path=task.few_shot_examples_path, binary=True)
# The few-shots are in the bidirectional format, we need to convert them to single-phrase.
few_shot_examples_single_phrase: list[NLISampleSinglePhrase] = []
for few_shot_example in few_shot_examples:
@@ -94,10 +93,12 @@ class ReportGroundingNLIProcessor(BaseProcessor[NLIQuerySample, NLISample]):
    NUM_LLM_SUCCESS = "num_llm_success"
    NUM_LLM_PHRASE_REWRITES = "num_llm_phrase_rewrites"

-    def __init__(self, format_query_fn: Callable[..., Any] | None = None) -> None:
+    def __init__(self, report_type: ReportType, format_query_fn: Callable[..., Any] | None = None) -> None:
        super().__init__()
        self.format_query_fn = format_query_fn
-        self.phrase_processor = get_ev_processor_singlephrase(log_dir=OUTPUT_DIR / "ev_processor_logs")
+        self.phrase_processor = get_ev_processor_singlephrase(
+            report_type=report_type, log_dir=OUTPUT_DIR / "ev_processor_logs"
+        )
        # Logging errors
        self.num_llm_failures = 0
        self.num_llm_success = 0
@@ -187,10 +188,16 @@ def format_row_to_nli_query_sample(row: "pd.Series[Any]") -> NLIQuerySample:


def get_report_nli_engine(
-    cfg: DictConfig, candidates: dict[str, GroundedPhraseList], references: dict[str, GroundedPhraseList]
+    cfg: DictConfig,
+    candidates: dict[str, GroundedPhraseList],
+    references: dict[str, GroundedPhraseList],
+    report_type: ReportType = ReportType.CXR,
) -> LLMEngine:
    output_folder = get_subfolder(root=OUTPUT_DIR, subfolder=RADFACT_SUBFOLDER)
-    nli_report_processor = ReportGroundingNLIProcessor(format_query_fn=format_row_to_nli_query_sample)
+
+    nli_report_processor = ReportGroundingNLIProcessor(
+        report_type=report_type, format_query_fn=format_row_to_nli_query_sample
+    )
    dataset_df = pd.DataFrame(
        {
            MetricDataframeKeys.STUDY_ID: study_id,
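The `prompt_tasks` module itself is not shown in this diff. A plausible shape, consistent with `NLITaskOptions[report_type.name].value` and the `system_message_path` / `few_shot_examples_path` attributes used above, might be the following sketch (all class and path names beyond those appearing in the diff are assumptions):

```python
from dataclasses import dataclass
from enum import Enum
from pathlib import Path


class ReportType(str, Enum):
    # String values match the CLI choices, so ReportType("ct") works.
    CXR = "cxr"
    CT = "ct"


@dataclass(frozen=True)
class NLITask:
    # Prompt assets for one report type.
    system_message_path: Path
    few_shot_examples_path: Path


class NLITaskOptions(Enum):
    # Keyed by ReportType member names, so NLITaskOptions[report_type.name].value
    # resolves to the matching NLITask. Paths below are illustrative only; the cxr
    # filenames echo the ones removed from processor.py, the ct ones are guesses.
    CXR = NLITask(
        system_message_path=Path("prompts/nli/system_message_ev_singlephrase.txt"),
        few_shot_examples_path=Path("prompts/nli/few_shot_examples.json"),
    )
    CT = NLITask(
        system_message_path=Path("prompts/nli/ct_system_message_ev_singlephrase.txt"),
        few_shot_examples_path=Path("prompts/nli/ct_few_shot_examples.json"),
    )
```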