diff --git a/README.md b/README.md index 63c76b8..03fe2c9 100644 --- a/README.md +++ b/README.md @@ -134,7 +134,8 @@ You can refer to the [getting_started](getting_started.ipynb) notebook to see ho ```bash $ run_radfact --help -usage: run_radfact [-h] [--radfact_config_name RADFACT_CONFIG_NAME] [--phrases_config_name PHRASES_CONFIG_NAME] --input_path INPUT_PATH [--is_narrative_text] [--output_dir OUTPUT_DIR] [--bootstrap_samples BOOTSTRAP_SAMPLES] +usage: run_radfact [-h] --input_path INPUT_PATH [--is_narrative_text] [--radfact_config_name RADFACT_CONFIG_NAME] [--phrases_config_name PHRASES_CONFIG_NAME] [--filtering_config_name FILTERING_CONFIG_NAME] [--output_dir OUTPUT_DIR] +[--bootstrap_samples BOOTSTRAP_SAMPLES] [--report_type {cxr,ct}] [--filter_negatives] Compute RadFact metric for a set of samples and saves the results to a json file. @@ -153,12 +154,15 @@ options: The name of the config file for reports to phrases conversion. We use the default config file but you can provide a custom config. Make sure the config follows the same structure as `configs/report_to_phrases.yaml` and is saved in the `configs` directory. This is necessary for hydra initialization from the `configs` directory. + --filtering_config_name FILTERING_CONFIG_NAME + The name of the config file for negative finding filtering. We use the default config file but you can provide a custom config. Make sure the config follows the same structure as `configs/negative_filtering.yaml` and is saved in the `configs` directory. This is necessary for hydra initialization from the `configs` directory. --output_dir OUTPUT_DIR Path to the directory where the results will be saved as a json file. --bootstrap_samples BOOTSTRAP_SAMPLES Number of bootstrap samples to use for computing the confidence intervals. Set to 0 to disable bootstrapping. --report_type {cxr,ct} Type of report: 'cxr' for chest x-ray reports or 'ct' for CT reports. 
+ --filter_negatives Whether to filter negative findings from the parsed reports before computing the RadFact score. ``` - for non-grounded reports (findings generation narrative text): @@ -179,7 +183,7 @@ The script computes confidence intervals for the metrics using bootstrapping. Th ⚠️**WARNING**: Some queries may fail due to the endpoint limitations (timeouts, rate limits, etc.). When the LLM performing entailment verification fails, we **set these examples as not-entailed by default**. If this occurs in a significant number of cases, the results will not be reliable. The final metrics dict contains the number of such skipped queries under the key `num_llm_failures`. The script will print the number of skipped queries at the end of the run, and store these in the `skipped` directroy under the run id folder. You will also see a warning message in the logs for each failed query. `WARNING: No response for example {query_id}. Setting as NOT ENTAILED`. -### Supporting Multiple Report Rypes +### Supporting Multiple Report Types RadFact supports different report types through the `report_type` field in the `RadFactMetric` class. Currently supported options are: - `cxr` - Chest X-ray reports (default) @@ -195,6 +199,13 @@ We also provide a script to convert reports to phrases. This is useful when you This script is configurable using the `report_to_phrases.yaml` config file. You can specify the input file, output file, and the endpoint to use for the conversion. +### Filtering Negative Phrases +Radiology reports can have a disproportionate number of negative findings, and filtering these out can help focus evaluation on clinically relevant positive findings. For non-grounded reports, RadFact can be configured to filter out negative findings once reports have been converted to phrases. Note that this feature is currently only available for CT reports. + +```bash + run_radfact --input_path <path_to_input_file.csv> --is_narrative_text --filter_negatives +``` + ## What is RadFact? 
![Illustration of RadFact](RadFact.png "Illustration of RadFact") diff --git a/configs/negative_filtering.yaml b/configs/negative_filtering.yaml new file mode 100644 index 0000000..0d70080 --- /dev/null +++ b/configs/negative_filtering.yaml @@ -0,0 +1,9 @@ +# @package _global_ + +defaults: + - default + - override endpoints: azure_chat_openai + - _self_ + +processing: + index_col: sentence_id \ No newline at end of file diff --git a/src/radfact/cli/run_radfact.py b/src/radfact/cli/run_radfact.py index f2140e9..55e2151 100644 --- a/src/radfact/cli/run_radfact.py +++ b/src/radfact/cli/run_radfact.py @@ -64,17 +64,21 @@ def get_candidates_and_references_from_json( def compute_radfact_scores( radfact_config_name: str | None, phrases_config_name: str | None, + filtering_config_name: str | None, candidates: InputDict, references: InputDict, is_narrative_text: bool, report_type: ReportType, bootstrap_samples: int, + filter_negatives: bool, ) -> dict[str, float]: radfact_metric = RadFactMetric( nli_config_name=radfact_config_name, phrase_config_name=phrases_config_name, + filtering_config_name=filtering_config_name, is_narrative_text=is_narrative_text, report_type=report_type, + filter_negatives=filter_negatives, ) if bootstrap_samples == 0: _, results = radfact_metric.compute_metric_score(candidates, references) @@ -121,6 +125,15 @@ def main() -> None: "initialization from the `configs` directory.", default=None, ) + parser.add_argument( + "--filtering_config_name", + type=str, + help="The name of the config file for negative finding filtering. We use the default config file but you can " + "provide a custom config. Make sure the config follows the same structure as `configs/negative_filtering.yaml` " + "and is saved in the `configs` directory. 
This is necessary for hydra initialization from the `configs` " + "directory.", + default=None, + ) parser.add_argument( "--output_dir", type=str, @@ -141,6 +154,11 @@ def main() -> None: help="Type of report: 'cxr' for chest x-ray reports or 'ct' for CT reports.", default="cxr", ) + parser.add_argument( + "--filter_negatives", + action="store_true", + help="Whether to filter negative findings from the parsed reports before computing the RadFact score.", + ) args = parser.parse_args() input_path = Path(args.input_path) @@ -148,8 +166,10 @@ def main() -> None: is_narrative_text = args.is_narrative_text radfact_config_name = args.radfact_config_name phrases_config_name = args.phrases_config_name + filtering_config_name = args.filtering_config_name bootstrap_samples = args.bootstrap_samples report_type = ReportType(args.report_type) + filter_negatives = args.filter_negatives assert input_path.suffix in [".csv", ".json"], "Input file must be a csv or json file." assert input_path.suffix == ".csv" or not is_narrative_text, ( @@ -170,11 +190,13 @@ def main() -> None: results = compute_radfact_scores( radfact_config_name=radfact_config_name, phrases_config_name=phrases_config_name, + filtering_config_name=filtering_config_name, candidates=candidates, references=references, is_narrative_text=is_narrative_text, bootstrap_samples=bootstrap_samples, report_type=report_type, + filter_negatives=filter_negatives, ) print_fn = print_results if bootstrap_samples == 0 else print_bootstrap_results diff --git a/src/radfact/llm_utils/negative_filtering/__init__.py b/src/radfact/llm_utils/negative_filtering/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/radfact/llm_utils/negative_filtering/processor.py b/src/radfact/llm_utils/negative_filtering/processor.py new file mode 100644 index 0000000..e6e949c --- /dev/null +++ b/src/radfact/llm_utils/negative_filtering/processor.py @@ -0,0 +1,140 @@ +# 
------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. +# ------------------------------------------------------------------------------------------ + +from collections import defaultdict +import json +from pathlib import Path + +import pandas as pd +from radfact.llm_utils.prompt_tasks import NEGATIVE_FILTERING_PARSING_TASK, NegativeFilteringTaskOptions, ReportType +from omegaconf import DictConfig + +from radfact.llm_utils.engine.engine import LLMEngine, get_subfolder +from radfact.llm_utils.processor.structured_processor import StructuredProcessor, parse_examples_from_json +from radfact.llm_utils.report_to_phrases.schema import ( + ParsedReport, + PhraseList, + PhraseListExample, + SentenceWithRephrases, +) +from radfact.paths import OUTPUT_DIR + +ORIG = "orig" +NEW = "new" + + +def get_negative_filtering_phrase_processor( + report_type: ReportType, log_dir: Path | None = None +) -> StructuredProcessor[list[str], PhraseList]: + """Return a processor for filtering negative findings from a list of phrases. + + :param report_type: The type of report, e.g., "ReportType.CXR" or "ReportType.CT". + :param log_dir: The directory to save logs. + :return: The processor for negative finding filtering. 
+ """ + task = NegativeFilteringTaskOptions[report_type.name].value + system_prompt = task.system_message_path.read_text() + few_shot_examples = parse_examples_from_json(task.few_shot_examples_path, PhraseListExample) + processor = StructuredProcessor( + query_type=list[str], + result_type=PhraseList, + system_prompt=system_prompt, + format_query_fn=lambda x: json.dumps(x), + few_shot_examples=few_shot_examples, + log_dir=log_dir, + ) + return processor + + +def load_filtering_queries_from_parsed_reports( + reports: list[ParsedReport], + index_col: str, +) -> pd.DataFrame: + """ + Load queries for filtering from a list of parsed reports. Queries consist of all the + newly parsed phrases from phrasification, along with metadata including the study ID + and original phrase. + + :param reports: A list of ParsedReport objects. + :param index_col: The column containing the index + :return: A dataframe of queries. + """ + queries = [] + for report in reports: + for i, sentence in enumerate(report.sentence_list): + queries.append([f"{report.id}_{i}", sentence.orig, sentence.new]) + query_df = pd.DataFrame(queries, columns=[index_col, ORIG, NEW]) + return query_df + + +def get_negative_filtering_engine( + cfg: DictConfig, parsed_reports: list[ParsedReport], subfolder_prefix: str, report_type: ReportType +) -> LLMEngine: + """ + Create the processing engine for filtering negative findings from parsed reports. + + :param cfg: The configuration for the processing engine. + :param parsed_reports: A list of ParsedReport objects to filter. + :param subfolder_prefix: The prefix for the metric folder + :param report_type: The type of report, e.g., CT. + :return: The processing engine. 
+ """ + OUTPUT_FOLDER = OUTPUT_DIR / NEGATIVE_FILTERING_PARSING_TASK + output_folder = get_subfolder(OUTPUT_FOLDER, subfolder_prefix) + final_output_folder = get_subfolder(OUTPUT_FOLDER, subfolder_prefix) + log_dir = get_subfolder(OUTPUT_FOLDER, "logs") + + query_df = load_filtering_queries_from_parsed_reports(parsed_reports, cfg.processing.index_col) + negative_filtering_processor = get_negative_filtering_phrase_processor(report_type=report_type, log_dir=log_dir) + + engine = LLMEngine( + cfg=cfg, + processor=negative_filtering_processor, + dataset_df=query_df, + row_to_query_fn=lambda row: row[NEW], + progress_output_folder=output_folder, + final_output_folder=final_output_folder, + ) + return engine + + +def process_filtered_reports(engine: LLMEngine, cfg: DictConfig) -> tuple[list[ParsedReport], int]: + """ + Process the filtered reports using the provided engine. + + :param engine: The LLMEngine used for processing. + :param cfg: The configuration for negative filtering processing. + :return: A tuple containing a list of ParsedReport objects and the number of rewritten sentences. + """ + outputs = engine.return_raw_outputs + metadata = engine.return_dataset_subsets + + parsed_report_dict = defaultdict(list) + num_rewritten_sentences = 0 + + for k in outputs.keys(): + phrase_list = outputs[k] + metadata_df = metadata[k].df + + for idx, row in metadata_df.iterrows(): + study_id = row[cfg.processing.index_col].rsplit("_", 1)[0] + orig = row[ORIG] + unfiltered_phrases = set(row[NEW]) + filtered_phrases = set(phrase_list[idx].phrases) + + if not filtered_phrases.issubset(unfiltered_phrases): + rewritten_phrases = filtered_phrases - unfiltered_phrases + print( + f"New phrases {rewritten_phrases} not in original phrases {unfiltered_phrases}. Reverting back to original phrases." 
+ ) + filtered_phrases = unfiltered_phrases + num_rewritten_sentences += 1 + + parsed_report_dict[study_id].append(SentenceWithRephrases(orig=orig, new=list(filtered_phrases))) + + parsed_reports = [ + ParsedReport(id=study_id, sentence_list=sentences) for study_id, sentences in parsed_report_dict.items() + ] + return parsed_reports, num_rewritten_sentences diff --git a/src/radfact/llm_utils/negative_filtering/prompts/ct/few_shot_examples.json b/src/radfact/llm_utils/negative_filtering/prompts/ct/few_shot_examples.json new file mode 100644 index 0000000..cfd6fd3 --- /dev/null +++ b/src/radfact/llm_utils/negative_filtering/prompts/ct/few_shot_examples.json @@ -0,0 +1,44 @@ +[ + { + "input": [ + "There is a relative opacity observed in the left mid-to-lower lung, possibly located in the lingula.", + "There is no evidence of pneumothorax.", + "The cardiac silhouette is unremarkable.", + "The mediastinal silhouette is unremarkable.", + "Mild recessions are observed in the upper lobe of the left lung." + ], + "output": { + "phrases": [ + "There is a relative opacity observed in the left mid-to-lower lung, possibly located in the lingula.", + "Mild recessions are observed in the upper lobe of the left lung." + ] + } + }, { + "input": [ + "The right lung is well aerated.", + "No signs of pulmonary edema.", + "No signs of focal consolidation.", + "The left side still shows mediastinal shifting and volume loss.", + "No signs of pleural effusions." + ], + "output": { + "phrases": [ + "The left side still shows mediastinal shifting and volume loss." + ] + } + }, { + "input": [ + "There is a moderate right pleural effusion.", + "There is no pneumothorax.", + "The heart size is within normal limits.", + "The radiograph shows linear opacities in the right middle lobe and left lower lobe, indicating atelectasis.", + "The mediastinal contours are unremarkable." 
+ ], + "output": { + "phrases": [ + "There is a moderate right pleural effusion.", + "The radiograph shows linear opacities in the right middle lobe and left lower lobe, indicating atelectasis." + ] + } + } +] \ No newline at end of file diff --git a/src/radfact/llm_utils/negative_filtering/prompts/ct/system_message.txt b/src/radfact/llm_utils/negative_filtering/prompts/ct/system_message.txt new file mode 100644 index 0000000..9b01b2a --- /dev/null +++ b/src/radfact/llm_utils/negative_filtering/prompts/ct/system_message.txt @@ -0,0 +1,13 @@ +You are an AI radiology assistant. You are helping process reports from CT (computed tomography) scans. + +You are given a list of phrases from a radiology report which refer to objects, findings, or anatomies visible in a CT scan, or the absence of such. + +Your goal is to filter phrases that do not refer to positive radiology findings. + +Rules: +- Remove statements describing the absence of pathology (e.g. "No pneumothorax", "No pleural effusion detected") +- Remove statements describing normal anatomical appearance, calibration, or function (e.g. "The liver is normal in size", "Upper abdominal organs are normal", "Thoracic esophageal calibration was normal", "The lungs are well aerated", "Lungs are clear") +- Remove statements describing unremarkable appearances (e.g. "Kidneys appear unremarkable", "The mediastinum is unremarkable") +- Keep statements referring to "mild" observations or conditions, as those are still considered positive radiology findings + +The objective is to remove phrases which do not refer to positive radiology findings. \ No newline at end of file diff --git a/src/radfact/llm_utils/processor/structured_processor.py b/src/radfact/llm_utils/processor/structured_processor.py index dc3256d..0b8667f 100644 --- a/src/radfact/llm_utils/processor/structured_processor.py +++ b/src/radfact/llm_utils/processor/structured_processor.py @@ -3,6 +3,7 @@ # Licensed under the MIT License (MIT). 
See LICENSE in the repo root for license information. # ------------------------------------------------------------------------------------------ +import json import logging from enum import Enum from functools import partial @@ -22,6 +23,7 @@ _QUERY_KEY = "query" ResultT = TypeVar("ResultT", bound=BaseModel) +ExampleClassT = TypeVar("ExampleClassT", bound=BaseModel) ProcessorStats = dict[str, int] @@ -55,6 +57,34 @@ class Example(Protocol, Generic[QueryT, ResultT]): output: ResultT +def parse_examples_from_json(examples_path: Path | None, example_class: type[ExampleClassT]) -> list[ExampleClassT]: + """ + This function returns a list of "parsed" examples from a JSON file. + + This JSON file is expected to contain a list of JSON-formatted objects, which should + be parseable by the "example class" (expected to be some Pydantic model). + + If no path is provided, an empty list is returned. + + This function is especially useful for loading few-shot examples for a structured processor. + + :param examples_path: Path to the JSON file containing the examples. + If None, an empty list is returned. + :param example_class: The class of the examples to load. A Pydantic model. + We will attempt to parse each object in the JSON file as an instance of this class. + :return: List of examples, as instances of the provided class. 
+ """ + parsed_examples: list[ExampleClassT] = [] + if examples_path is None: + return parsed_examples + + with open(examples_path) as f: + unparsed_examples = json.load(f) + for example in unparsed_examples: + parsed_examples.append(example_class.parse_obj(example)) + return parsed_examples + + class QueryTemplate(BaseChatPromptTemplate, Generic[QueryT, ResultT]): """Query template for a structured processor.""" diff --git a/src/radfact/llm_utils/prompt_tasks.py b/src/radfact/llm_utils/prompt_tasks.py index 6067b67..5ca5d04 100644 --- a/src/radfact/llm_utils/prompt_tasks.py +++ b/src/radfact/llm_utils/prompt_tasks.py @@ -5,6 +5,8 @@ REPORT_TO_PHRASES_PARSING_TASK = "report_to_phrases" REPORT_TO_PHRASES_PROMPTS_DIR = get_prompts_dir(task=REPORT_TO_PHRASES_PARSING_TASK) +NEGATIVE_FILTERING_PARSING_TASK = "negative_filtering" +NEGATIVE_FILTERING_PROMPTS_DIR = get_prompts_dir(task=NEGATIVE_FILTERING_PARSING_TASK) NLI_PARSING_TASK = "nli" NLI_PROMPTS_DIR = get_prompts_dir(task=NLI_PARSING_TASK) @@ -34,6 +36,14 @@ class ReportToPhrasesTaskOptions(Enum): ) +class NegativeFilteringTaskOptions(Enum): + CT = PromptTask( + name=f"{ReportType.CT.value}_negative_filtering", + system_message_path=NEGATIVE_FILTERING_PROMPTS_DIR / ReportType.CT.value / "system_message.txt", + few_shot_examples_path=NEGATIVE_FILTERING_PROMPTS_DIR / ReportType.CT.value / "few_shot_examples.json", + ) + + class NLITaskOptions(Enum): CXR = PromptTask( name=f"{ReportType.CXR.value}_nli", diff --git a/src/radfact/llm_utils/report_to_phrases/processor.py b/src/radfact/llm_utils/report_to_phrases/processor.py index 39d623f..05a3331 100644 --- a/src/radfact/llm_utils/report_to_phrases/processor.py +++ b/src/radfact/llm_utils/report_to_phrases/processor.py @@ -7,7 +7,11 @@ from typing import Any import pandas as pd -from radfact.llm_utils.prompt_tasks import REPORT_TO_PHRASES_PARSING_TASK, ReportToPhrasesTaskOptions, ReportType +from radfact.llm_utils.prompt_tasks import ( + REPORT_TO_PHRASES_PARSING_TASK, 
+ ReportToPhrasesTaskOptions, + ReportType, +) from omegaconf import DictConfig from radfact.llm_utils.engine.engine import LLMEngine, get_subfolder @@ -50,18 +54,19 @@ def get_findings_from_row(row: "pd.Series[Any]") -> str: def get_report_to_phrases_engine( - cfg: DictConfig, dataset_df: pd.DataFrame, report_type: ReportType = ReportType.CXR + cfg: DictConfig, dataset_df: pd.DataFrame, subfolder_prefix: str = "", report_type: ReportType = ReportType.CXR ) -> LLMEngine: """ Create the processing engine for converting reports to phrases. :param cfg: The configuration for the processing engine. :param dataset_df: The dataset DataFrame. + :param subfolder_prefix: The prefix for the metric folder :param report_type: The type of report, e.g., CXR or CT. :return: The processing engine. """ subfolder = cfg.dataset.name - root = OUTPUT_DIR / REPORT_TO_PHRASES_PARSING_TASK + root = OUTPUT_DIR / REPORT_TO_PHRASES_PARSING_TASK / subfolder_prefix output_folder = get_subfolder(root, subfolder) final_output_folder = get_subfolder(root, subfolder) log_dir = get_subfolder(root, "logs") diff --git a/src/radfact/llm_utils/report_to_phrases/schema.py b/src/radfact/llm_utils/report_to_phrases/schema.py index 06579ff..ef9e92e 100644 --- a/src/radfact/llm_utils/report_to_phrases/schema.py +++ b/src/radfact/llm_utils/report_to_phrases/schema.py @@ -13,6 +13,12 @@ from radfact.llm_utils.processor.base_processor import BaseModelWithId +class PhraseList(BaseModel): + """Dataclass for a list of phrases.""" + + phrases: list[str] + + class SentenceWithRephrases(BaseModel): """Dataclass for a sentence with rephrases. 
The source sentence is 'orig' and the rephrased sentences are 'new'.""" @@ -71,6 +77,13 @@ def to_grounded_phrases_list(self, rephrased: bool = True) -> GroundedPhraseList return sequence +class PhraseListExample(BaseModel): + """A single example of a list of phrases before and after processing""" + + input: list[str] + output: PhraseList + + class PhraseParsingExample(BaseModel): """Dataclass for a single example.""" diff --git a/src/radfact/metric/radfact.py b/src/radfact/metric/radfact.py index 0dd93e5..0b7f2b7 100644 --- a/src/radfact/metric/radfact.py +++ b/src/radfact/metric/radfact.py @@ -18,6 +18,7 @@ from radfact.llm_utils.nli.processor import get_report_nli_engine from radfact.llm_utils.nli.schema import EVState, NLISample from radfact.llm_utils.report_to_phrases.processor import FINDINGS_SECTION, StudyIdType, get_report_to_phrases_engine +from radfact.llm_utils.negative_filtering.processor import get_negative_filtering_engine, process_filtered_reports from radfact.llm_utils.report_to_phrases.schema import ParsedReport from radfact.metric.box_metrics import PRECISION, compute_box_metrics from radfact.metric.schema import ( @@ -45,6 +46,11 @@ RADFACT_CONFIG = "radfact.yaml" # The YAML config file for the phrase processor in this setting. REPORT_TO_PHRASES_CONFIG = "report_to_phrases.yaml" +# The YAML config file for the negative filtering processor in this setting. +NEGATIVE_FILTERING_CONFIG = "negative_filtering.yaml" + +GENERATIONS = "generations" +GROUND_TRUTH = "ground_truth" def init_hydra_config(config_name: str) -> DictConfig: @@ -69,10 +75,12 @@ def __init__( self, nli_config_name: str | None = None, phrase_config_name: str | None = None, + filtering_config_name: str | None = None, image_size: int = 224, box_precision_threshold: float = 0.5, is_narrative_text: bool = False, report_type: ReportType = ReportType.CXR, + filter_negatives: bool = False, ) -> None: """ Initializes the RadFactMetric with the necessary configurations. 
We need to know the image size so we can @@ -82,6 +90,9 @@ def __init__( different endpoints that the NLI processor will use. If None, the default config will be used. :param phrase_config_name: The name of the phrase processing config file. This is the config file that specifies the different endpoints that the phrase processor will use. If None, the default config will be used. + :param filtering_config_name: The name of the negative filtering processing config file. This is the config file + that specifies the different endpoints that the negative filtering processor will use. If None, the default config + will be used. :param image_size: The size of the images in the reports. :param box_precision_threshold: The threshold for precision computation for boxes. :param is_narrative_text: If True, we are running the metric on data narrative text data, e.g. the original @@ -89,14 +100,20 @@ def __init__( If False, we are running the metric on grounded reports, where the phrases are already in the correct format for entailment verification. :param report_type: The type of report, e.g. CXR or CT + :param filter_negatives: If True, we will filter negative findings from the parsed reports before computing + the RadFact score. """ self.llm_nli_cfg = init_hydra_config(nli_config_name or RADFACT_CONFIG) self.llm_phrase_cfg = init_hydra_config(phrase_config_name or REPORT_TO_PHRASES_CONFIG) + self.llm_negative_filtering_cfg = init_hydra_config(filtering_config_name or NEGATIVE_FILTERING_CONFIG) self.report_type = report_type self.image_size = image_size self.box_precision_threshold = box_precision_threshold self.is_narrative_text = is_narrative_text self.meta_metrics: dict[str, float] = {} # Metrics about the metric, derived from processors. Not per-sample. + self.filter_negatives = filter_negatives + if self.filter_negatives: + assert self.report_type == ReportType.CT, "Negative filtering is only supported for CT reports." 
def _are_boxes_entailed(self, boxes: list[NormalizedBox] | None, evidence_boxes: list[NormalizedBox]) -> bool: """ @@ -210,15 +227,46 @@ def convert_narrative_text_to_phrases( texts_as_str_df = pd.DataFrame( {id_col: study_id, FINDINGS_SECTION: texts_as_str[study_id]} for study_id in texts_as_str.keys() ) - engine = get_report_to_phrases_engine(self.llm_phrase_cfg, texts_as_str_df, self.report_type) + + if metric_prefix.endswith(GENERATIONS): + subfolder_prefix = GENERATIONS + elif metric_prefix.endswith(GROUND_TRUTH): + subfolder_prefix = GROUND_TRUTH + else: + subfolder_prefix = "" + + engine = get_report_to_phrases_engine(self.llm_phrase_cfg, texts_as_str_df, subfolder_prefix, self.report_type) parsed_reports: list[ParsedReport] = engine.run() - processed_texts = { - parsed.id: parsed.to_grounded_phrases_list() for parsed in parsed_reports if parsed.id is not None - } + if engine.aggregated_processor_stats is not None: self.meta_metrics.update( {f"{metric_prefix}/{k}": float(v) for k, v in engine.aggregated_processor_stats.items()} ) + + if self.filter_negatives: + logger.info("Filtering negatives from previous run.") + engine = get_negative_filtering_engine( + self.llm_negative_filtering_cfg, + parsed_reports, + subfolder_prefix, + report_type=self.report_type, + ) + engine.run() + parsed_reports, num_rewritten_sentences = process_filtered_reports(engine, self.llm_negative_filtering_cfg) + if engine.aggregated_processor_stats is not None: + self.meta_metrics.update( + { + f"negative_filtering_{metric_prefix}/{k}": float(v) + for k, v in engine.aggregated_processor_stats.items() + } + ) + self.meta_metrics[f"negative_filtering_{metric_prefix}/num_rewritten_sentences"] = ( + num_rewritten_sentences + ) + + processed_texts = { + parsed.id: parsed.to_grounded_phrases_list() for parsed in parsed_reports if parsed.id is not None + } if set(processed_texts.keys()) != set(texts.keys()): logger.warning( f"Key mismatch between processed and input texts. 
#input keys: {len(set(texts.keys()))}. #processed " diff --git a/tests/metric/test_radfact.py b/tests/metric/test_radfact.py index 776614f..ba31f61 100644 --- a/tests/metric/test_radfact.py +++ b/tests/metric/test_radfact.py @@ -3,11 +3,15 @@ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. # ------------------------------------------------------------------------------------------ +import copy +import functools import shutil from pathlib import Path import mock import pandas as pd +from radfact.llm_utils.report_to_phrases.schema import PhraseList +from radfact.llm_utils.engine.data_subset import DataSubset from radfact.llm_utils.nli.processor import get_ev_processor_singlephrase from radfact.paths import OUTPUT_DIR from radfact.llm_utils.prompt_tasks import ReportType @@ -253,7 +257,9 @@ def test_nli_processing_with_endpoint(mock_nli_engine: mock.Mock) -> None: } -def get_mock_phrase_engine(llm_phrase_cfg: DictConfig, df: pd.DataFrame, report_type: ReportType) -> mock.Mock: +def get_mock_phrase_engine( + llm_phrase_cfg: DictConfig, df: pd.DataFrame, subfolder_prefix: str, report_type: ReportType +) -> mock.Mock: mock_phrase_engine = mock.Mock() if df["FINDINGS"].values[0] == "The cat The dog The bird The rabbit": mock_phrase_engine.run.return_value = [ @@ -283,6 +289,57 @@ def get_mock_phrase_engine(llm_phrase_cfg: DictConfig, df: pd.DataFrame, report_ return mock_phrase_engine +def get_mock_filtering_engine( + llm_negative_filtering_cfg: DictConfig, + parsed_reports: list[ParsedReport], + subfolder_prefix: str, + report_type: ReportType, + tmp_path: Path, +) -> mock.Mock: + mock_filtering_engine = mock.Mock() + new_parsed_reports = copy.deepcopy(parsed_reports) + + # Inject a rewritten phrase to simulate filtering mistake + new_parsed_reports[0].sentence_list[0].new.append("Dummy filtered phrase") + mock_filtering_engine.return_raw_outputs = { + "endpoint_1": [PhraseList(phrases=report.sentence_list[0].new) for report 
in new_parsed_reports] + } + if parsed_reports[0].sentence_list[0].orig == "The cat The dog The bird The rabbit": + mock_filtering_engine.return_dataset_subsets = { + "endpoint_1": DataSubset( + start_index=0, + end_index=1, + index_col="sentence_id", + output_folder=tmp_path, + df=pd.DataFrame( + { + "sentence_id": ["study1_0"], + "orig": ["The cat The dog The bird The rabbit"], + "new": [["The cat", "The dog", "The bird", "The rabbit"]], + } + ), + ) + } + else: + mock_filtering_engine.return_dataset_subsets = { + "endpoint_1": DataSubset( + start_index=0, + end_index=1, + index_col="sentence_id", + output_folder=tmp_path, + df=pd.DataFrame( + { + "sentence_id": ["study1_0"], + "orig": ["The cat The dog The bird The shark"], + "new": [["The cat", "The dog", "The bird", "The shark"]], + } + ), + ) + } + mock_filtering_engine.aggregated_processor_stats = {'num_failures': 0, 'num_success': 1} + return mock_filtering_engine + + def test_nli_processing_with_endpoint_and_report_to_phrases(mock_nli_engine: mock.Mock) -> None: """Test that the RadFact metric works end-to-end, with a mocked engine including report-to-phrases processing.""" progress_subfolder = Path(LLMEngine.OUTPUT_FILES_PREFIX) / RADFACT_SUBFOLDER @@ -323,6 +380,70 @@ def test_nli_processing_with_endpoint_and_report_to_phrases(mock_nli_engine: moc assert_equal(actual=details, desired=expected_details, verbose=True) +def test_nli_processing_with_negative_filtering(mock_nli_engine: mock.Mock, tmp_path: Path) -> None: + """Test that the GPT metric works end-to-end, when connecting to an actual endpoint, with the Redis Cache, + phrasification, and negative filtering. 
+ """ + progress_subfolder = Path(LLMEngine.OUTPUT_FILES_PREFIX) / RADFACT_SUBFOLDER + shutil.rmtree(progress_subfolder, ignore_errors=True) + metric = RadFactMetric(is_narrative_text=True, report_type=ReportType.CT, filter_negatives=True) + with mock.patch('radfact.metric.radfact.get_report_nli_engine', return_value=mock_nli_engine): + with mock.patch('radfact.metric.radfact.get_report_to_phrases_engine', side_effect=get_mock_phrase_engine): + with mock.patch( + 'radfact.metric.radfact.get_negative_filtering_engine', + side_effect=functools.partial(get_mock_filtering_engine, tmp_path=tmp_path), + ): + result, details = metric.compute_metric_score(candidates_narrative, references_narrative) + + assert isinstance(result, float) + assert result == 0.75 + assert isinstance(details, dict) + + expected_details = { + "logical_precision": 0.75, + "logical_recall": 0.75, + "spatial_precision": 0.0, + "spatial_recall": 0.0, + "grounding_precision": 0.0, + "grounding_recall": 0.0, + "num_candidate_phrases": 4, + "num_reference_phrases": 4, + "num_candidate_phrases_with_boxes": 0, + "num_reference_phrases_with_boxes": 0, + "logical_f1": 0.75, + "spatial_f1": 0.0, + "grounding_f1": 0.0, + "num_samples": 1, + "num_llm_failures": 0, + "num_llm_success": 8, + "num_llm_phrase_rewrites": 0, + "num_invalid_processed_samples": 0, + "report_to_phrases/generations/num_failures": 0, + "report_to_phrases/generations/num_success": 1, + 'negative_filtering_report_to_phrases/generations/num_failures': 0.0, + 'negative_filtering_report_to_phrases/generations/num_success': 1.0, + 'negative_filtering_report_to_phrases/generations/num_rewritten_sentences': 1, + "report_to_phrases/ground_truth/num_failures": 0, + "report_to_phrases/ground_truth/num_success": 1, + 'negative_filtering_report_to_phrases/ground_truth/num_failures': 0.0, + 'negative_filtering_report_to_phrases/ground_truth/num_success': 1.0, + 'negative_filtering_report_to_phrases/ground_truth/num_rewritten_sentences': 1, + 
"report_to_phrases/num_dropped_candidates": 0, + "report_to_phrases/num_dropped_references": 0, + } + assert_equal(actual=details, desired=expected_details, verbose=True) + + +def test_nli_processing_fails_with_negative_filtering_config_error() -> None: + """ + Test that an error is raised if negative filtering is enabled but the report type is not CT. + """ + progress_subfolder = Path(LLMEngine.OUTPUT_FILES_PREFIX) / RADFACT_SUBFOLDER + shutil.rmtree(progress_subfolder, ignore_errors=True) + with pytest.raises(AssertionError): + RadFactMetric(is_narrative_text=True, report_type=ReportType.CXR, filter_negatives=True) + + def test_convert_input_to_multimodal() -> None: """ Test that we can convert the input to a multimodal grounded sequence correctly.