diff --git a/CHANGELOG.md b/CHANGELOG.md index bbe7d69..8234ffe 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,8 +1,21 @@ -## [unreleased] +## [0.3.0] - 02.11.2025 ### 🚀 Features -- Added TextEmbeddings transformer, for compute embedding using SentenceTransformers +- Added [TextEmbeddings](https://scaledp.stabrise.com/en/latest/models/embeddings/TextEmbeddings.html) transformer, for computing embeddings using SentenceTransformers +- Added BaseTextSplitter and [TextSplitter](https://scaledp.stabrise.com/en/latest/models/splitters/text_splitter.html) for semantic text splitting +- Added pandas UDF support for TextSplitter +- Added support for TextChunks as input to TextEmbeddings + +### 📚 Documentation + +- Added TextEmbeddings and TextSplitter docs + +### 📘 Jupyter Notebooks + +- [TextSplitterAndEmbeddings.ipynb](https://github.com/StabRise/ScaleDP-Tutorials/blob/master/embeddings/1.TextSplitterAndEmbeddings.ipynb ) - Read PDF documents, split text into chunks and compute embeddings + ## [0.2.6] - 19.11.2025 diff --git a/pyproject.toml b/pyproject.toml index 4329430..988b795 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "scaledp" -version = "0.2.6" +version = "0.3.0" description = "ScaleDP is a library for processing documents and images using Apache Spark and LLMs" authors = ["Mykola Melnyk "] repository = "https://github.com/StabRise/scaledp" diff --git a/scaledp/__init__.py b/scaledp/__init__.py index 192ebdb..684a117 100644 --- a/scaledp/__init__.py +++ b/scaledp/__init__.py @@ -18,6 +18,7 @@ from scaledp.models.detectors.SignatureDetector import SignatureDetector from scaledp.models.detectors.YoloDetector import YoloDetector from scaledp.models.detectors.YoloOnnxDetector import YoloOnnxDetector +from scaledp.models.embeddings.TextEmbeddings import TextEmbeddings from scaledp.models.extractors.DSPyExtractor import DSPyExtractor from scaledp.models.extractors.LLMExtractor import LLMExtractor from 
scaledp.models.extractors.LLMVisualExtractor import LLMVisualExtractor @@ -29,6 +30,7 @@ from scaledp.models.recognizers.SuryaOcr import SuryaOcr from scaledp.models.recognizers.TesseractOcr import TesseractOcr from scaledp.models.recognizers.TesseractRecognizer import TesseractRecognizer +from scaledp.models.splitters.TextSplitter import TextSplitter from scaledp.pdf.PdfAddTextLayer import PdfAddTextLayer from scaledp.pdf.PdfAssembler import PdfAssembler from scaledp.pdf.PdfDataToDocument import PdfDataToDocument @@ -214,6 +216,8 @@ def ScaleDPSession( "TesseractOcr", "Ner", "TextToDocument", + "TextSplitter", + "TextEmbeddings", "LayoutDetector", "PipelineModel", "SuryaOcr", diff --git a/scaledp/models/embeddings/TextEmbeddings.py b/scaledp/models/embeddings/TextEmbeddings.py index 43c51c3..f109c3a 100644 --- a/scaledp/models/embeddings/TextEmbeddings.py +++ b/scaledp/models/embeddings/TextEmbeddings.py @@ -1,4 +1,5 @@ import json +import time from dataclasses import asdict from types import MappingProxyType from typing import Any @@ -57,7 +58,7 @@ def _transform(self, dataset): # Handle TextChunks input - use explode to create multiple rows if not self.getPartitionMap(): # Use UDF with explode - result = dataset.withColumn( + result = dataset.select(in_col).withColumn( out_col, explode( udf( @@ -75,7 +76,7 @@ def _transform(self, dataset): dataset = dataset.repartition(self.getPathCol()) dataset = dataset.coalesce(self.getNumPartitions()) - result = dataset.withColumn( + result = dataset.select(in_col).withColumn( out_col, explode( pandas_udf( @@ -130,38 +131,55 @@ def _is_text_chunks_column(self, dataset, col_name: str) -> bool: def transform_udf(self, text: str): model = self.get_model() + start_time = time.time() embedding = model.encode( text, batch_size=self.getBatchSize(), device=self.getSTDevice(), ) + processing_time = time.time() - start_time return EmbeddingsOutput( path="memory", + page=0, + text=text, data=embedding.tolist(), type="text", + 
processing_time=processing_time, exception="", ) def transform_udf_chunks(self, text_chunks: TextChunks): - """Transform TextChunks into embeddings, preserving path information.""" + """Transform TextChunks into embeddings, preserving path information + and per-item processing time (batched). + """ if not text_chunks or not text_chunks.chunks: return [] + start_time = time.time() model = self.get_model() embeddings = model.encode( text_chunks.chunks, batch_size=self.getBatchSize(), device=self.getSTDevice(), ) + total_processing_time = time.time() - start_time + per_item_time = ( + total_processing_time / len(text_chunks.chunks) + if text_chunks.chunks + else 0.0 + ) + results = [] - for embedding in embeddings: + for i, embedding in enumerate(embeddings): results.append( EmbeddingsOutput( path=text_chunks.path or "memory", data=embedding.tolist(), + page=text_chunks.page, + text=text_chunks.chunks[i], type="text_chunk", exception=text_chunks.exception or "", - processing_time=text_chunks.processing_time or 0.0, + processing_time=per_item_time, ), ) return results @@ -170,18 +188,26 @@ def transform_udf_chunks(self, text_chunks: TextChunks): def transform_udf_pandas(texts: pd.Series, params: pd.Series) -> pd.DataFrame: params = json.loads(params[0]) model = SentenceTransformer(params["model"]) + start_time = time.time() embeddings = model.encode( texts.tolist(), batch_size=params["batchSize"], device="cpu" if params["device"] == Device.CPU.value else "cuda", ) + total_processing_time = time.time() - start_time + per_item_time = ( + total_processing_time / texts.shape[0] if len(texts) else 0.0 + ) results = [] - for embedding in embeddings: + for i, embedding in enumerate(embeddings): results.append( EmbeddingsOutput( path="memory", + page=0, + text=texts.iloc[i], data=embedding.tolist(), type="text", + processing_time=per_item_time, exception="", ), ) @@ -199,21 +225,26 @@ def transform_udf_pandas_chunks( results = [] for _, row in chunks_df.iterrows(): if 
len(row["chunks"]): + start_time = time.time() embeddings = model.encode( row["chunks"], batch_size=params["batchSize"], device="cpu" if params["device"] == Device.CPU.value else "cuda", ) + total_processing_time = time.time() - start_time + per_item_time = total_processing_time / len(row["chunks"]) emb_results = [] - for embedding in embeddings: + for i, embedding in enumerate(embeddings): emb_results.append( asdict( EmbeddingsOutput( path=row.get("path") or "memory", data=embedding.tolist(), + text=row["chunks"][i], + page=row.get("page"), type="text_chunk", exception=row.get("exception") or "", - processing_time=row.get("processing_time") or 0.0, + processing_time=per_item_time or 0.0, ), ), ) diff --git a/scaledp/models/splitters/BaseSplitter.py b/scaledp/models/splitters/BaseSplitter.py index 526e51f..226bc46 100644 --- a/scaledp/models/splitters/BaseSplitter.py +++ b/scaledp/models/splitters/BaseSplitter.py @@ -9,6 +9,7 @@ HasKeepInputData, HasNumPartitions, HasOutputCol, + HasPageCol, HasPartitionMap, HasWhiteList, ) @@ -24,6 +25,7 @@ class BaseSplitter( HasPartitionMap, HasChunkSize, HasChunkOverlap, + HasPageCol, ABC, ): """ diff --git a/scaledp/models/splitters/BaseTextSplitter.py b/scaledp/models/splitters/BaseTextSplitter.py index 806b043..a874a31 100644 --- a/scaledp/models/splitters/BaseTextSplitter.py +++ b/scaledp/models/splitters/BaseTextSplitter.py @@ -36,6 +36,7 @@ class BaseTextSplitter( "chunk_overlap": 0, "numPartitions": 1, "partitionMap": False, + "pageCol": "page_number", }, ) @@ -50,18 +51,19 @@ def get_params(self): return json.dumps({k.name: v for k, v in self.extractParamMap().items()}) @abstractmethod - def split(self, document: Document) -> TextChunks: + def split(self, document: Document, page_number: int) -> TextChunks: """ Split a document into chunks. 
Args: document: The document to split + page_number: The page number of the document Returns: TextChunks object containing the chunks and metadata """ - def transform_udf(self, document_struct): + def transform_udf(self, document_struct, page_number): """ Transform UDF that splits text into chunks. @@ -72,13 +74,14 @@ def transform_udf(self, document_struct): TextChunks object containing the chunks """ # document_struct is already a Document object - result = self.split(document_struct) + result = self.split(document_struct, page_number) return result @classmethod def transform_udf_pandas( cls, documents: pd.DataFrame, + page_numbers: pd.Series, params: pd.Series, ) -> pd.DataFrame: """ @@ -94,7 +97,7 @@ def transform_udf_pandas( params_dict = json.loads(params.iloc[0]) splitter = cls(**params_dict) results = [] - for _, doc_row in documents.iterrows(): + for i, doc_row in documents.iterrows(): # Convert Row to Document # When using pandas_udf with Arrow, the struct comes as # a Row object with field attributes @@ -104,7 +107,7 @@ def transform_udf_pandas( if isinstance(doc_row, Document) else Document(**doc_row.to_dict()) ) - output = splitter.split(doc) + output = splitter.split(doc, page_numbers.iloc[i]) except (AttributeError, TypeError, Exception) as e: # If something goes wrong, create an error result output = TextChunks( @@ -130,6 +133,7 @@ def _transform(self, dataset): params = self.get_params() out_col = self.getOutputCol() input_col = self.getInputCol() + page_col = self.getPageCol() # Validate input column exists if input_col not in dataset.columns: @@ -137,12 +141,16 @@ def _transform(self, dataset): # Validate input column validated_input_col = self._validate(input_col, dataset) + validated_page_col = self._validate(page_col, dataset) if not self.getPartitionMap(): # Regular mode: use UDF result = dataset.withColumn( out_col, - udf(self.transform_udf, TextChunks.get_schema())(validated_input_col), + udf(self.transform_udf, TextChunks.get_schema())( 
validated_input_col, + validated_page_col, + ), ) else: # Pandas mode: use pandas_udf @@ -153,6 +161,7 @@ def _transform(self, dataset): out_col, pandas_udf(self.transform_udf_pandas, TextChunks.get_schema())( validated_input_col, + validated_page_col, lit(params), ), ) diff --git a/scaledp/models/splitters/TextSplitter.py b/scaledp/models/splitters/TextSplitter.py index 46875f8..73b5297 100644 --- a/scaledp/models/splitters/TextSplitter.py +++ b/scaledp/models/splitters/TextSplitter.py @@ -31,7 +31,7 @@ def _get_splitter(self): chunk_size = self.getOrDefault("chunk_size") return SemanticTextSplitter(chunk_size) - def split(self, document: Document) -> TextChunks: + def split(self, document: Document, page_number: int) -> TextChunks: start_time = time.time() try: splitter = self._get_splitter() @@ -44,6 +44,7 @@ def split(self, document: Document) -> TextChunks: return TextChunks( path=document.path, chunks=chunks, + page=page_number, exception=exception, processing_time=processing_time, ) diff --git a/scaledp/schemas/EmbeddingsOutput.py b/scaledp/schemas/EmbeddingsOutput.py index 3e6abab..65e522f 100644 --- a/scaledp/schemas/EmbeddingsOutput.py +++ b/scaledp/schemas/EmbeddingsOutput.py @@ -8,6 +8,8 @@ @dataclass(order=True) class EmbeddingsOutput: path: Optional[str] + page: Optional[int] + text: Optional[str] data: Optional[list[float]] type: Optional[str] exception: Optional[str] = "" diff --git a/scaledp/schemas/TextChunks.py b/scaledp/schemas/TextChunks.py index b424c5d..2d9ac43 100644 --- a/scaledp/schemas/TextChunks.py +++ b/scaledp/schemas/TextChunks.py @@ -7,6 +7,7 @@ @dataclass(order=True) class TextChunks: path: Optional[str] + page: Optional[int] chunks: Optional[list[str]] exception: Optional[str] = "" processing_time: Optional[float] = 0.0 diff --git a/tests/conftest.py b/tests/conftest.py index 041abc9..a21224f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -247,16 +247,20 @@ def document_df(spark_session, resource_path_root): df = 
spark_session.read.text(text_path, wholetext=True) # Create Document struct from text and path - return df.withColumn( - "document", - struct( - lit(text_path).alias("path"), - col("value").alias("text"), - lit("text").alias("type"), - lit([]).cast(ArrayType(StructType([]))).alias("bboxes"), - lit("").alias("exception"), - ), - ).select("document") + return ( + df.withColumn( + "document", + struct( + lit(text_path).alias("path"), + col("value").alias("text"), + lit("text").alias("type"), + lit([]).cast(ArrayType(StructType([]))).alias("bboxes"), + lit("").alias("exception"), + ), + ) + .withColumn("page_number", lit(0)) + .select("document", "page_number") + ) @pytest.fixture @@ -265,16 +269,16 @@ def df_text_chunks(spark_session): from pyspark.sql.functions import struct chunks_data = [ - ("file1.txt", ["hello world", "this is a test"], "", 1.0), - ("file2.txt", ["another chunk", "more text"], "", 2.0), - ("file3.txt", ["final chunk"], "", 0.5), + ("file1.txt", 0, ["hello world", "this is a test"], "", 1.0), + ("file2.txt", 0, ["another chunk", "more text"], "", 2.0), + ("file3.txt", 0, ["final chunk"], "", 0.5), ] return spark_session.createDataFrame( chunks_data, schema=TextChunks.get_schema(), ).select( - struct("path", "chunks", "exception", "processing_time").alias( + struct("path", "page", "chunks", "exception", "processing_time").alias( "text_chunks_col", ), ) diff --git a/tests/models/splitters/test_text_splitter.py b/tests/models/splitters/test_text_splitter.py index 163f9ee..afd7ee0 100644 --- a/tests/models/splitters/test_text_splitter.py +++ b/tests/models/splitters/test_text_splitter.py @@ -37,13 +37,14 @@ def test_text_splitter_split_method(): document = Document(path="test.txt", text=long_text, type="text", bboxes=[]) # Split the document - result = text_splitter.split(document) + result = text_splitter.split(document, 0) # Verify the result assert result.path == "test.txt" assert result.exception == "" assert len(result.chunks) > 1 assert 
result.processing_time > 0 + assert result.page == 0 # Verify all chunks are strings assert all(isinstance(chunk, str) for chunk in result.chunks) @@ -60,7 +61,7 @@ def test_text_splitter_split_method_with_exception(): document = Document(path="test.txt", text=None, type="text", bboxes=[]) # Split the document - result = text_splitter.split(document) + result = text_splitter.split(document, 0) # Verify the result contains exception information assert result.path == "test.txt" diff --git a/tutorials b/tutorials index 2cd3d33..c9dadc0 160000 --- a/tutorials +++ b/tutorials @@ -1 +1 @@ -Subproject commit 2cd3d33e1fbc866571061087a32be67f8798df1c +Subproject commit c9dadc091aa36a98c8ec3749b7fbe87d0adef6a9