Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,21 @@
## [unreleased]
## [0.3.0] - 02.11.2025

### 🚀 Features

- Added TextEmbeddings transformer, for compute embedding using SentenceTransformers
- Added [TextEmbeddings](https://scaledp.stabrise.com/en/latest/models/embeddings/TextEmbeddings.html) transformer for computing embeddings using SentenceTransformers
- Added BaseTextSplitter and [TextSplitter](https://scaledp.stabrise.com/en/latest/models/splitters/text_splitter.html) for semantic splitting of text
- Added pandas UDF support for TextSplitter
- Added support for TextChunks as input to TextEmbeddings

### 📚 Documentation

- Added TextEmbedding and TextSplitter docs

### 📘 Jupyter Notebooks

- [TextSplitterAndEmbeddings.ipynb](https://github.com/StabRise/ScaleDP-Tutorials/blob/master/embeddings/1.TextSplitterAndEmbeddings.ipynb) - Read PDF documents, split text into chunks and compute embeddings



## [0.2.6] - 19.11.2025
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "scaledp"
version = "0.2.6"
version = "0.3.0"
description = "ScaleDP is a library for processing documents and images using Apache Spark and LLMs"
authors = ["Mykola Melnyk <mykola@stabrise.com>"]
repository = "https://github.com/StabRise/scaledp"
Expand Down
4 changes: 4 additions & 0 deletions scaledp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from scaledp.models.detectors.SignatureDetector import SignatureDetector
from scaledp.models.detectors.YoloDetector import YoloDetector
from scaledp.models.detectors.YoloOnnxDetector import YoloOnnxDetector
from scaledp.models.embeddings.TextEmbeddings import TextEmbeddings
from scaledp.models.extractors.DSPyExtractor import DSPyExtractor
from scaledp.models.extractors.LLMExtractor import LLMExtractor
from scaledp.models.extractors.LLMVisualExtractor import LLMVisualExtractor
Expand All @@ -29,6 +30,7 @@
from scaledp.models.recognizers.SuryaOcr import SuryaOcr
from scaledp.models.recognizers.TesseractOcr import TesseractOcr
from scaledp.models.recognizers.TesseractRecognizer import TesseractRecognizer
from scaledp.models.splitters.TextSplitter import TextSplitter
from scaledp.pdf.PdfAddTextLayer import PdfAddTextLayer
from scaledp.pdf.PdfAssembler import PdfAssembler
from scaledp.pdf.PdfDataToDocument import PdfDataToDocument
Expand Down Expand Up @@ -214,6 +216,8 @@ def ScaleDPSession(
"TesseractOcr",
"Ner",
"TextToDocument",
"TextSplitter",
"TextEmbeddings",
"LayoutDetector",
"PipelineModel",
"SuryaOcr",
Expand Down
47 changes: 39 additions & 8 deletions scaledp/models/embeddings/TextEmbeddings.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import time
from dataclasses import asdict
from types import MappingProxyType
from typing import Any
Expand Down Expand Up @@ -57,7 +58,7 @@ def _transform(self, dataset):
# Handle TextChunks input - use explode to create multiple rows
if not self.getPartitionMap():
# Use UDF with explode
result = dataset.withColumn(
result = dataset.select(in_col).withColumn(
out_col,
explode(
udf(
Expand All @@ -75,7 +76,7 @@ def _transform(self, dataset):
dataset = dataset.repartition(self.getPathCol())
dataset = dataset.coalesce(self.getNumPartitions())

result = dataset.withColumn(
result = dataset.select(in_col).withColumn(
out_col,
explode(
pandas_udf(
Expand Down Expand Up @@ -130,38 +131,55 @@ def _is_text_chunks_column(self, dataset, col_name: str) -> bool:

def transform_udf(self, text: str):
model = self.get_model()
start_time = time.time()
embedding = model.encode(
text,
batch_size=self.getBatchSize(),
device=self.getSTDevice(),
)
processing_time = time.time() - start_time
return EmbeddingsOutput(
path="memory",
page=0,
text=text,
data=embedding.tolist(),
type="text",
processing_time=processing_time,
exception="",
)

def transform_udf_chunks(self, text_chunks: TextChunks):
"""Transform TextChunks into embeddings, preserving path information."""
"""Transform TextChunks into embeddings, preserving path information
and per-item processing time (batched).
"""
if not text_chunks or not text_chunks.chunks:
return []

start_time = time.time()
model = self.get_model()
embeddings = model.encode(
text_chunks.chunks,
batch_size=self.getBatchSize(),
device=self.getSTDevice(),
)
total_processing_time = time.time() - start_time
per_item_time = (
total_processing_time / len(text_chunks.chunks)
if text_chunks.chunks
else 0.0
)

results = []
for embedding in embeddings:
for i, embedding in enumerate(embeddings):
results.append(
EmbeddingsOutput(
path=text_chunks.path or "memory",
data=embedding.tolist(),
page=text_chunks.page,
text=text_chunks.chunks[i],
type="text_chunk",
exception=text_chunks.exception or "",
processing_time=text_chunks.processing_time or 0.0,
processing_time=per_item_time,
),
)
return results
Expand All @@ -170,18 +188,26 @@ def transform_udf_chunks(self, text_chunks: TextChunks):
def transform_udf_pandas(texts: pd.Series, params: pd.Series) -> pd.DataFrame:
params = json.loads(params[0])
model = SentenceTransformer(params["model"])
start_time = time.time()
embeddings = model.encode(
texts.tolist(),
batch_size=params["batchSize"],
device="cpu" if params["device"] == Device.CPU.value else "cuda",
)
total_processing_time = time.time() - start_time
per_item_time = (
total_processing_time / texts.shape[0] if texts is not None else 0.0
)
results = []
for embedding in embeddings:
for i, embedding in enumerate(embeddings):
results.append(
EmbeddingsOutput(
path="memory",
page=0,
text=texts.iloc[i],
data=embedding.tolist(),
type="text",
processing_time=per_item_time,
exception="",
),
)
Expand All @@ -199,21 +225,26 @@ def transform_udf_pandas_chunks(
results = []
for _, row in chunks_df.iterrows():
if len(row["chunks"]):
start_time = time.time()
embeddings = model.encode(
row["chunks"],
batch_size=params["batchSize"],
device="cpu" if params["device"] == Device.CPU.value else "cuda",
)
total_processing_time = time.time() - start_time
per_item_time = total_processing_time / len(row["chunks"])
emb_results = []
for embedding in embeddings:
for i, embedding in enumerate(embeddings):
emb_results.append(
asdict(
EmbeddingsOutput(
path=row.get("path") or "memory",
data=embedding.tolist(),
text=row["chunks"][i],
page=row.get("page"),
type="text_chunk",
exception=row.get("exception") or "",
processing_time=row.get("processing_time") or 0.0,
processing_time=per_item_time or 0.0,
),
),
)
Expand Down
2 changes: 2 additions & 0 deletions scaledp/models/splitters/BaseSplitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
HasKeepInputData,
HasNumPartitions,
HasOutputCol,
HasPageCol,
HasPartitionMap,
HasWhiteList,
)
Expand All @@ -24,6 +25,7 @@ class BaseSplitter(
HasPartitionMap,
HasChunkSize,
HasChunkOverlap,
HasPageCol,
ABC,
):
"""
Expand Down
21 changes: 15 additions & 6 deletions scaledp/models/splitters/BaseTextSplitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class BaseTextSplitter(
"chunk_overlap": 0,
"numPartitions": 1,
"partitionMap": False,
"pageCol": "page_number",
},
)

Expand All @@ -50,18 +51,19 @@ def get_params(self):
return json.dumps({k.name: v for k, v in self.extractParamMap().items()})

@abstractmethod
def split(self, document: Document) -> TextChunks:
def split(self, document: Document, pagen_number: int) -> TextChunks:
"""
Split a document into chunks.

Args:
document: The document to split
pagen_number: The page number of the document

Returns:
TextChunks object containing the chunks and metadata
"""

def transform_udf(self, document_struct):
def transform_udf(self, document_struct, page_number):
"""
Transform UDF that splits text into chunks.

Expand All @@ -72,13 +74,14 @@ def transform_udf(self, document_struct):
TextChunks object containing the chunks
"""
# document_struct is already a Document object
result = self.split(document_struct)
result = self.split(document_struct, page_number)
return result

@classmethod
def transform_udf_pandas(
cls,
documents: pd.DataFrame,
page_numbers: pd.Series,
params: pd.Series,
) -> pd.DataFrame:
"""
Expand All @@ -94,7 +97,7 @@ def transform_udf_pandas(
params_dict = json.loads(params.iloc[0])
splitter = cls(**params_dict)
results = []
for _, doc_row in documents.iterrows():
for i, doc_row in documents.iterrows():
# Convert Row to Document
# When using pandas_udf with Arrow, the struct comes as
# a Row object with field attributes
Expand All @@ -104,7 +107,7 @@ def transform_udf_pandas(
if isinstance(doc_row, Document)
else Document(**doc_row.to_dict())
)
output = splitter.split(doc)
output = splitter.split(doc, page_numbers.iloc[i])
except (AttributeError, TypeError, Exception) as e:
# If something goes wrong, create an error result
output = TextChunks(
Expand All @@ -130,19 +133,24 @@ def _transform(self, dataset):
params = self.get_params()
out_col = self.getOutputCol()
input_col = self.getInputCol()
page_col = self.getPageCol()

# Validate input column exists
if input_col not in dataset.columns:
raise ValueError(f"Column {input_col} not found in dataset")

# Validate input column
validated_input_col = self._validate(input_col, dataset)
validated_page_col = self._validate(page_col, dataset)

if not self.getPartitionMap():
# Regular mode: use UDF
result = dataset.withColumn(
out_col,
udf(self.transform_udf, TextChunks.get_schema())(validated_input_col),
udf(self.transform_udf, TextChunks.get_schema())(
validated_input_col,
validated_page_col,
),
)
else:
# Pandas mode: use pandas_udf
Expand All @@ -153,6 +161,7 @@ def _transform(self, dataset):
out_col,
pandas_udf(self.transform_udf_pandas, TextChunks.get_schema())(
validated_input_col,
validated_page_col,
lit(params),
),
)
Expand Down
3 changes: 2 additions & 1 deletion scaledp/models/splitters/TextSplitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def _get_splitter(self):
chunk_size = self.getOrDefault("chunk_size")
return SemanticTextSplitter(chunk_size)

def split(self, document: Document) -> TextChunks:
def split(self, document: Document, page_number: int) -> TextChunks:
start_time = time.time()
try:
splitter = self._get_splitter()
Expand All @@ -44,6 +44,7 @@ def split(self, document: Document) -> TextChunks:
return TextChunks(
path=document.path,
chunks=chunks,
page=page_number,
exception=exception,
processing_time=processing_time,
)
2 changes: 2 additions & 0 deletions scaledp/schemas/EmbeddingsOutput.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
@dataclass(order=True)
class EmbeddingsOutput:
path: Optional[str]
page: Optional[int]
text: Optional[str]
data: Optional[list[float]]
type: Optional[str]
exception: Optional[str] = ""
Expand Down
1 change: 1 addition & 0 deletions scaledp/schemas/TextChunks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
@dataclass(order=True)
class TextChunks:
path: Optional[str]
page: Optional[int]
chunks: Optional[list[str]]
exception: Optional[str] = ""
processing_time: Optional[float] = 0.0
Expand Down
32 changes: 18 additions & 14 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,16 +247,20 @@ def document_df(spark_session, resource_path_root):
df = spark_session.read.text(text_path, wholetext=True)

# Create Document struct from text and path
return df.withColumn(
"document",
struct(
lit(text_path).alias("path"),
col("value").alias("text"),
lit("text").alias("type"),
lit([]).cast(ArrayType(StructType([]))).alias("bboxes"),
lit("").alias("exception"),
),
).select("document")
return (
df.withColumn(
"document",
struct(
lit(text_path).alias("path"),
col("value").alias("text"),
lit("text").alias("type"),
lit([]).cast(ArrayType(StructType([]))).alias("bboxes"),
lit("").alias("exception"),
),
)
.withColumn("page_number", lit(0))
.select("document", "page_number")
)


@pytest.fixture
Expand All @@ -265,16 +269,16 @@ def df_text_chunks(spark_session):
from pyspark.sql.functions import struct

chunks_data = [
("file1.txt", ["hello world", "this is a test"], "", 1.0),
("file2.txt", ["another chunk", "more text"], "", 2.0),
("file3.txt", ["final chunk"], "", 0.5),
("file1.txt", 0, ["hello world", "this is a test"], "", 1.0),
("file2.txt", 0, ["another chunk", "more text"], "", 2.0),
("file3.txt", 0, ["final chunk"], "", 0.5),
]

return spark_session.createDataFrame(
chunks_data,
schema=TextChunks.get_schema(),
).select(
struct("path", "chunks", "exception", "processing_time").alias(
struct("path", "page", "chunks", "exception", "processing_time").alias(
"text_chunks_col",
),
)
Loading