129 changes: 106 additions & 23 deletions llm_utils/tools.py
@@ -1,17 +1,47 @@
import os
from typing import List, Dict
from typing import List, Dict, Optional, TypeVar, Callable, Iterable, Any

from langchain.schema import Document

from data_utils.datahub_source import DatahubMetadataFetcher
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor

T = TypeVar("T")
R = TypeVar("R")


def parallel_process(
items: Iterable[T],
process_fn: Callable[[T], R],
max_workers: int = 8,
desc: Optional[str] = None,
show_progress: bool = True,
) -> List[R]:
"""병렬 처리를 위한 유틸리티 함수

Args:
items (Iterable[T]): 처리할 아이템들
process_fn (Callable[[T], R]): 각 아이템을 처리할 함수
max_workers (int, optional): 최대 쓰레드 수. Defaults to 8.
desc (Optional[str], optional): 진행 상태 메시지. Defaults to None.
show_progress (bool, optional): 진행 상태 표시 여부. Defaults to True.

Returns:
List[R]: 처리 결과 리스트
"""
with ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = [executor.submit(process_fn, item) for item in items]
if show_progress:
futures = tqdm(futures, desc=desc)
return [future.result() for future in futures]
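
# A minimal usage sketch of parallel_process (editor's illustration; the sample
# items and the str.upper callable are assumptions, not part of this module):
#
#     upper = parallel_process(["a", "b", "c"], str.upper, max_workers=2, desc="demo")
#     # -> ["A", "B", "C"], preserving submission order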


def set_gms_server(gms_server: str):
try:
os.environ["DATAHUB_SERVER"] = gms_server
fetcher = DatahubMetadataFetcher(gms_server=gms_server)
except ValueError as e:
# Raise an exception when the GMS server address is invalid
raise ValueError(f"Failed to configure GMS server: {str(e)}")


@@ -22,47 +52,100 @@ def _get_fetcher():
return DatahubMetadataFetcher(gms_server=gms_server)


def _get_table_info() -> Dict[str, str]:
"""전체 테이블 이름과 설명을 가져오는 함수"""
def _process_urn(urn: str, fetcher: DatahubMetadataFetcher) -> tuple[str, str]:
table_name = fetcher.get_table_name(urn)
table_description = fetcher.get_table_description(urn)
return (table_name, table_description)


def _process_column_info(
urn: str, table_name: str, fetcher: DatahubMetadataFetcher
) -> Optional[List[Dict[str, str]]]:
if fetcher.get_table_name(urn) == table_name:
return fetcher.get_column_names_and_descriptions(urn)
return None


def _get_table_info(max_workers: int = 8) -> Dict[str, str]:
"""전체 테이블 이름과 설명을 가져오는 함수

Args:
max_workers (int, optional): 병렬 처리에 사용할 최대 쓰레드 수. Defaults to 8.

Returns:
Dict[str, str]: 테이블 이름과 설명을 담은 딕셔너리
"""
fetcher = _get_fetcher()
urns = fetcher.get_urns()
table_info = {}
for urn in urns:
table_name = fetcher.get_table_name(urn)
table_description = fetcher.get_table_description(urn)

results = parallel_process(
urns,
lambda urn: _process_urn(urn, fetcher),
max_workers=max_workers,
desc="테이블 정보 수집 중",
)

for table_name, table_description in results:
if table_name and table_description:
table_info[table_name] = table_description

return table_info
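
# Sketch of the shape returned by _get_table_info (the table names and
# descriptions here are hypothetical placeholders):
#
#     {"orders": "Order fact table", "users": "Registered user dimension"}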


def _get_column_info(table_name: str) -> List[Dict[str, str]]:
"""table_name에 해당하는 컬럼 이름과 설명을 가져오는 함수"""
def _get_column_info(table_name: str, max_workers: int = 8) -> List[Dict[str, str]]:
"""table_name에 해당하는 컬럼 이름과 설명을 가져오는 함수

Args:
table_name (str): 테이블 이름
max_workers (int, optional): 병렬 처리에 사용할 최대 쓰레드 수. Defaults to 8.

Returns:
List[Dict[str, str]]: 컬럼 정보 리스트
"""
fetcher = _get_fetcher()
urns = fetcher.get_urns()
for urn in urns:
if fetcher.get_table_name(urn) == table_name:
return fetcher.get_column_names_and_descriptions(urn)

results = parallel_process(
urns,
lambda urn: _process_column_info(urn, table_name, fetcher),
max_workers=max_workers,
show_progress=False,
)

for result in results:
if result:
return result
return []
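
# Sketch of the shape returned by _get_column_info, matching the keys consumed
# in get_info_from_db below (the sample values are hypothetical):
#
#     [{"column_name": "id", "column_description": "Primary key"},
#      {"column_name": "created_at", "column_description": "Row creation time"}]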


def get_info_from_db() -> List[Document]:
"""
Fetch every table name and description together with column names and descriptions.
def get_info_from_db(max_workers: int = 8) -> List[Document]:
"""전체 테이블 이름과 설명, 컬럼 이름과 설명을 가져오는 함수

Args:
max_workers (int, optional): 병렬 처리에 사용할 최대 쓰레드 수. Defaults to 8.

Returns:
List[Document]: 테이블과 컬럼 정보를 담은 Document 객체 리스트
"""
table_info = _get_table_info(max_workers=max_workers)

table_info_str_list = []
table_info = _get_table_info()
for table_name, table_description in table_info.items():
column_info = _get_column_info(table_name)
def process_table_info(item: tuple[str, str]) -> str:
table_name, table_description = item
column_info = _get_column_info(table_name, max_workers=max_workers)
column_info_str = "\n".join(
[
f"{col['column_name']}: {col['column_description']}"
for col in column_info
]
)
table_info_str_list.append(
f"{table_name}: {table_description}\nColumns:\n {column_info_str}"
)
return f"{table_name}: {table_description}\nColumns:\n {column_info_str}"

table_info_str_list = parallel_process(
table_info.items(),
process_table_info,
max_workers=max_workers,
desc="컬럼 정보 수집 중",
)

# Convert table_info_str_list into a list of Document objects
return [Document(page_content=info) for info in table_info_str_list]
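
# A minimal end-to-end sketch of how these helpers fit together (the GMS
# address and worker count are placeholders, not values from this PR):
#
#     set_gms_server("http://localhost:8080")
#     docs = get_info_from_db(max_workers=4)
#     print(docs[0].page_content)  # "<table>: <description>\nColumns:\n ..."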