diff --git a/src/exploit_iq_commons/utils/transitive_code_searcher_tool.py b/src/exploit_iq_commons/utils/transitive_code_searcher_tool.py index 8167f440..8938bd01 100644 --- a/src/exploit_iq_commons/utils/transitive_code_searcher_tool.py +++ b/src/exploit_iq_commons/utils/transitive_code_searcher_tool.py @@ -140,4 +140,4 @@ def search(self, query: str) -> tuple[bool, list[Document]]: f"-------------------------------------------\n{function_method.page_content}\n" logger.debug(content_of_files_in_path) logger.debug(content_of_files_in_path, extra=MULTI_LINE_MESSAGE_TRUE) - return found_path, call_hierarchy_list \ No newline at end of file + return found_path, call_hierarchy_list diff --git a/src/vuln_analysis/tools/tests/test_transitive_code_search.py b/src/vuln_analysis/tools/tests/test_transitive_code_search.py index c829f6e8..bffe731b 100644 --- a/src/vuln_analysis/tools/tests/test_transitive_code_search.py +++ b/src/vuln_analysis/tools/tests/test_transitive_code_search.py @@ -126,7 +126,7 @@ def set_input_for_next_run(git_repository: str, git_ref: str, included_extension async def get_transitive_code_runner_function(): - transitive_code_search = transitive_search(config=TransitiveCodeSearchToolConfig(), builder=None) + transitive_code_search = transitive_search(config=TransitiveCodeSearchToolConfig(), builder=None, verbose_mode=True) async for function in transitive_code_search.gen: return function.single_fn diff --git a/src/vuln_analysis/tools/transitive_code_search.py b/src/vuln_analysis/tools/transitive_code_search.py index 6106b317..3e2b25d4 100644 --- a/src/vuln_analysis/tools/transitive_code_search.py +++ b/src/vuln_analysis/tools/transitive_code_search.py @@ -108,7 +108,7 @@ def get_transitive_code_searcher(query: str): @register_function(config_type=TransitiveCodeSearchToolConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN]) async def transitive_search(config: TransitiveCodeSearchToolConfig, - builder: Builder): # pylint: disable=unused-argument + builder: Builder, verbose_mode: bool = False): # pylint: disable=unused-argument """ Call Chain Analyzer tool used to search source code function reachability. """ @@ -117,7 +117,17 @@ async def _arun(query: str) -> tuple: transitive_code_searcher: TransitiveCodeSearcher transitive_code_searcher = get_transitive_code_searcher(query) result = transitive_code_searcher.search(query) - return result + (answer, docs) = result + call_hierarchy_list_strings: list + if verbose_mode: + return result + else: + call_hierarchy_list_strings = list(map(lambda doc: + transitive_code_searcher + .chain_of_calls_retriever + .language_parser.get_function_name(doc) + ,docs)) + return answer, call_hierarchy_list_strings yield FunctionInfo.from_fn( _arun,