From 6026a1dcbab86fdcc2ce39a0bb907472866b479e Mon Sep 17 00:00:00 2001 From: DShomin Date: Wed, 16 Apr 2025 17:50:03 +0900 Subject: [PATCH] =?UTF-8?q?feat:=20=ED=85=8C=EC=9D=B4=EB=B8=94=20=EB=B0=8F?= =?UTF-8?q?=20=EC=BB=AC=EB=9F=BC=20=EC=A0=95=EB=B3=B4=EB=A5=BC=20=EB=B3=91?= =?UTF-8?q?=EB=A0=AC=EB=A1=9C=20=EC=88=98=EC=A7=91=ED=95=98=EB=8F=84?= =?UTF-8?q?=EB=A1=9D=20=EA=B0=9C=EC=84=A0=ED=95=98=EC=97=AC=20=EC=84=B1?= =?UTF-8?q?=EB=8A=A5=20=ED=96=A5=EC=83=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- llm_utils/tools.py | 129 +++++++++++++++++++++++++++++++++++++-------- 1 file changed, 106 insertions(+), 23 deletions(-) diff --git a/llm_utils/tools.py b/llm_utils/tools.py index 24e8c1d..ddc9bb2 100644 --- a/llm_utils/tools.py +++ b/llm_utils/tools.py @@ -1,9 +1,40 @@ import os -from typing import List, Dict +from typing import List, Dict, Optional, TypeVar, Callable, Iterable, Any from langchain.schema import Document from data_utils.datahub_source import DatahubMetadataFetcher +from tqdm import tqdm +from concurrent.futures import ThreadPoolExecutor + +T = TypeVar("T") +R = TypeVar("R") + + +def parallel_process[T, R]( + items: Iterable[T], + process_fn: Callable[[T], R], + max_workers: int = 8, + desc: Optional[str] = None, + show_progress: bool = True, +) -> List[R]: + """병렬 처리를 위한 유틸리티 함수 + + Args: + items (Iterable[T]): 처리할 아이템들 + process_fn (Callable[[T], R]): 각 아이템을 처리할 함수 + max_workers (int, optional): 최대 쓰레드 수. Defaults to 8. + desc (Optional[str], optional): 진행 상태 메시지. Defaults to None. + show_progress (bool, optional): 진행 상태 표시 여부. Defaults to True. + + Returns: + List[R]: 처리 결과 리스트 + """ + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [executor.submit(process_fn, item) for item in items] + if show_progress: + futures = tqdm(futures, desc=desc) + return [future.result() for future in futures] def set_gms_server(gms_server: str): @@ -11,7 +42,6 @@ def set_gms_server(gms_server: str): os.environ["DATAHUB_SERVER"] = gms_server fetcher = DatahubMetadataFetcher(gms_server=gms_server) except ValueError as e: - # 유효하지 않은 GMS 서버 주소일 경우 예외를 발생시킴 raise ValueError(f"GMS 서버 설정 실패: {str(e)}") @@ -22,47 +52,100 @@ def _get_fetcher(): return DatahubMetadataFetcher(gms_server=gms_server) -def _get_table_info() -> Dict[str, str]: - """전체 테이블 이름과 설명을 가져오는 함수""" +def _process_urn(urn: str, fetcher: DatahubMetadataFetcher) -> tuple[str, str]: + table_name = fetcher.get_table_name(urn) + table_description = fetcher.get_table_description(urn) + return (table_name, table_description) + + +def _process_column_info( + urn: str, table_name: str, fetcher: DatahubMetadataFetcher +) -> Optional[List[Dict[str, str]]]: + if fetcher.get_table_name(urn) == table_name: + return fetcher.get_column_names_and_descriptions(urn) + return None + + +def _get_table_info(max_workers: int = 8) -> Dict[str, str]: + """전체 테이블 이름과 설명을 가져오는 함수 + + Args: + max_workers (int, optional): 병렬 처리에 사용할 최대 쓰레드 수. Defaults to 8. + + Returns: + Dict[str, str]: 테이블 이름과 설명을 담은 딕셔너리 + """ fetcher = _get_fetcher() urns = fetcher.get_urns() table_info = {} - for urn in urns: - table_name = fetcher.get_table_name(urn) - table_description = fetcher.get_table_description(urn) + + results = parallel_process( + urns, + lambda urn: _process_urn(urn, fetcher), + max_workers=max_workers, + desc="테이블 정보 수집 중", + ) + + for table_name, table_description in results: if table_name and table_description: table_info[table_name] = table_description + return table_info -def _get_column_info(table_name: str) -> List[Dict[str, str]]: - """table_name에 해당하는 컬럼 이름과 설명을 가져오는 함수""" +def _get_column_info(table_name: str, max_workers: int = 8) -> List[Dict[str, str]]: + """table_name에 해당하는 컬럼 이름과 설명을 가져오는 함수 + + Args: + table_name (str): 테이블 이름 + max_workers (int, optional): 병렬 처리에 사용할 최대 쓰레드 수. Defaults to 8. + + Returns: + List[Dict[str, str]]: 컬럼 정보 리스트 + """ fetcher = _get_fetcher() urns = fetcher.get_urns() - for urn in urns: - if fetcher.get_table_name(urn) == table_name: - return fetcher.get_column_names_and_descriptions(urn) + + results = parallel_process( + urns, + lambda urn: _process_column_info(urn, table_name, fetcher), + max_workers=max_workers, + show_progress=False, + ) + + for result in results: + if result: + return result return [] -def get_info_from_db() -> List[Document]: - """ - 전체 테이블 이름과 설명, 컬럼 이름과 설명을 가져오는 함수 +def get_info_from_db(max_workers: int = 8) -> List[Document]: + """전체 테이블 이름과 설명, 컬럼 이름과 설명을 가져오는 함수 + + Args: + max_workers (int, optional): 병렬 처리에 사용할 최대 쓰레드 수. Defaults to 8. + + Returns: + List[Document]: 테이블과 컬럼 정보를 담은 Document 객체 리스트 """ + table_info = _get_table_info(max_workers=max_workers) - table_info_str_list = [] - table_info = _get_table_info() - for table_name, table_description in table_info.items(): - column_info = _get_column_info(table_name) + def process_table_info(item: tuple[str, str]) -> str: + table_name, table_description = item + column_info = _get_column_info(table_name, max_workers=max_workers) column_info_str = "\n".join( [ f"{col['column_name']}: {col['column_description']}" for col in column_info ] ) - table_info_str_list.append( - f"{table_name}: {table_description}\nColumns:\n {column_info_str}" - ) + return f"{table_name}: {table_description}\nColumns:\n {column_info_str}" + + table_info_str_list = parallel_process( + table_info.items(), + process_table_info, + max_workers=max_workers, + desc="컬럼 정보 수집 중", + ) - # table_info_str_list를 Document 객체 리스트로 변환 return [Document(page_content=info) for info in table_info_str_list]