From f9affc07561f2332bc9269bb8305d2b6e10f26d4 Mon Sep 17 00:00:00 2001 From: mykola Date: Thu, 20 Nov 2025 18:03:49 +0300 Subject: [PATCH 1/3] feat: Added support pandas udf for TextSplitter --- scaledp/models/splitters/BaseTextSplitter.py | 30 ++++++-------------- tests/models/splitters/test_text_splitter.py | 12 +++----- 2 files changed, 12 insertions(+), 30 deletions(-) diff --git a/scaledp/models/splitters/BaseTextSplitter.py b/scaledp/models/splitters/BaseTextSplitter.py index dc61650..806b043 100644 --- a/scaledp/models/splitters/BaseTextSplitter.py +++ b/scaledp/models/splitters/BaseTextSplitter.py @@ -78,7 +78,7 @@ def transform_udf(self, document_struct): @classmethod def transform_udf_pandas( cls, - documents: pd.Series, + documents: pd.DataFrame, params: pd.Series, ) -> pd.DataFrame: """ @@ -94,22 +94,16 @@ def transform_udf_pandas( params_dict = json.loads(params.iloc[0]) splitter = cls(**params_dict) results = [] - for doc_row in documents: + for _, doc_row in documents.iterrows(): # Convert Row to Document # When using pandas_udf with Arrow, the struct comes as # a Row object with field attributes try: - if isinstance(doc_row, Document): - doc = doc_row - else: - # Try to get attributes from Row object - doc = Document( - path=doc_row.path, - text=doc_row.text, - type=doc_row.type, - bboxes=doc_row.bboxes, - exception=getattr(doc_row, "exception", ""), - ) + doc = ( + doc_row + if isinstance(doc_row, Document) + else Document(**doc_row.to_dict()) + ) output = splitter.split(doc) except (AttributeError, TypeError, Exception) as e: # If something goes wrong, create an error result @@ -119,16 +113,8 @@ def transform_udf_pandas( exception=str(e), processing_time=0.0, ) - # Convert to dict to ensure proper schema - results.append( - { - "path": output.path, - "chunks": output.chunks, - "exception": output.exception, - "processing_time": output.processing_time, - }, - ) + results.append(output) return pd.DataFrame(results) def _transform(self, dataset): diff --git a/tests/models/splitters/test_text_splitter.py b/tests/models/splitters/test_text_splitter.py index 6b118ad..163f9ee 100644 --- a/tests/models/splitters/test_text_splitter.py +++ b/tests/models/splitters/test_text_splitter.py @@ -1,4 +1,3 @@ -import pytest from pyspark.ml import PipelineModel from scaledp.models.splitters.TextSplitter import TextSplitter @@ -70,7 +69,7 @@ def test_text_splitter_split_method_with_exception(): assert result.processing_time > 0 -def test_text_splitter_pipeline(text_splitter_df): +def test_text_splitter_pipeline(document_df): """Test TextSplitter in a PySpark pipeline.""" # Initialize the TextSplitter stage text_splitter = TextSplitter( @@ -82,7 +81,7 @@ def test_text_splitter_pipeline(text_splitter_df): # Create a pipeline with the TextSplitter stage pipeline = PipelineModel(stages=[text_splitter]) - result_df = pipeline.transform(text_splitter_df) + result_df = pipeline.transform(document_df) # Cache the result for performance result = result_df.select("chunks").cache() @@ -107,10 +106,7 @@ def test_text_splitter_pipeline(text_splitter_df): assert row.chunks.processing_time > 0 -@pytest.mark.skip( - reason="pandas_udf schema handling needs refinement for Document structs", -) -def test_text_splitter_pipeline_pandas(text_splitter_df): +def test_text_splitter_pipeline_pandas(document_df): """Test TextSplitter with partitionMap (pandas mode).""" # Initialize the TextSplitter stage text_splitter = TextSplitter( @@ -123,7 +119,7 @@ def test_text_splitter_pipeline_pandas(text_splitter_df): # Create a pipeline with the TextSplitter stage pipeline = PipelineModel(stages=[text_splitter]) - result_df = pipeline.transform(text_splitter_df) + result_df = pipeline.transform(document_df) # Cache the result for performance result = result_df.select("chunks").cache() From bed7bf2665e8f3dc2fe82c37f8daac78db5b43ae Mon Sep 17 00:00:00 2001 From: mykola Date: Thu, 20 Nov 2025 18:04:38 +0300 Subject: [PATCH 2/3] feat: Added support TextChunks as input to TextEmbeddings --- scaledp/models/embeddings/TextEmbeddings.py | 146 ++++++++++++++++++ tests/conftest.py | 24 ++- .../models/embeddings/test_text_embeddings.py | 75 +++++++++ 3 files changed, 244 insertions(+), 1 deletion(-) diff --git a/scaledp/models/embeddings/TextEmbeddings.py b/scaledp/models/embeddings/TextEmbeddings.py index e5f8520..43c51c3 100644 --- a/scaledp/models/embeddings/TextEmbeddings.py +++ b/scaledp/models/embeddings/TextEmbeddings.py @@ -1,14 +1,17 @@ import json +from dataclasses import asdict from types import MappingProxyType from typing import Any import pandas as pd from pyspark import keyword_only +from pyspark.sql.types import ArrayType, StructType from sentence_transformers import SentenceTransformer from scaledp.enums import Device from scaledp.models.embeddings.BaseEmbeddings import BaseEmbeddings from scaledp.schemas.EmbeddingsOutput import EmbeddingsOutput +from scaledp.schemas.TextChunks import TextChunks class TextEmbeddings(BaseEmbeddings): @@ -34,11 +37,97 @@ def __init__(self, **kwargs: Any) -> None: self._set(**kwargs) self._model = None + def _transform(self, dataset): + """Override _transform to handle both raw text and TextChunks input.""" + from pyspark.sql.functions import explode, lit, pandas_udf, udf + + params = self.get_params() + out_col = self.getOutputCol() + input_col = self.getInputCol() + + if input_col not in dataset.columns: + raise ValueError(f"Column {input_col} not found in dataset") + + # Check if input column is TextChunks (Array type) before validation + is_chunks = self._is_text_chunks_column(dataset, input_col) + + in_col = self._validate(input_col, dataset) + + if is_chunks: + # Handle TextChunks input - use explode to create multiple rows + if not self.getPartitionMap(): + # Use UDF with explode + result = dataset.withColumn( + out_col, + explode( + udf( + self.transform_udf_chunks, + ArrayType(EmbeddingsOutput.get_schema()), + )(in_col), + ), + ) + else: + # Use pandas_udf with explode + if self.getNumPartitions() > 0: + if self.getPageCol() in dataset.columns: + dataset = dataset.repartition(self.getPageCol()) + elif self.getPathCol() in dataset.columns: + dataset = dataset.repartition(self.getPathCol()) + dataset = dataset.coalesce(self.getNumPartitions()) + + result = dataset.withColumn( + out_col, + explode( + pandas_udf( + self.transform_udf_pandas_chunks, + ArrayType(EmbeddingsOutput.get_schema()), + )(in_col, lit(params)), + ), + ) + elif not self.getPartitionMap(): + result = dataset.withColumn( + out_col, + udf(self.transform_udf, EmbeddingsOutput.get_schema())(in_col), + ) + else: + if self.getNumPartitions() > 0: + if self.getPageCol() in dataset.columns: + dataset = dataset.repartition(self.getPageCol()) + elif self.getPathCol() in dataset.columns: + dataset = dataset.repartition(self.getPathCol()) + dataset = dataset.coalesce(self.getNumPartitions()) + result = dataset.withColumn( + out_col, + pandas_udf( + self.transform_udf_pandas, + EmbeddingsOutput.get_schema(), + )( + in_col, + lit(params), + ), + ) + + if not self.getKeepInputData(): + result = result.drop(in_col) + return result + def get_model(self): if self._model is None: self._model = SentenceTransformer(self.getModel()) return self._model + def _is_text_chunks_column(self, dataset, col_name: str) -> bool: + """Check if the column contains TextChunks data (matches TextChunks schema).""" + try: + column_type = dataset.schema[col_name].dataType + # Check if it's a StructType matching TextChunks schema + if isinstance(column_type, StructType): + expected_schema = TextChunks.get_schema() + return column_type == expected_schema + return False + except (KeyError, AttributeError): + return False + def transform_udf(self, text: str): model = self.get_model() embedding = model.encode( @@ -53,6 +142,30 @@ def transform_udf(self, text: str): exception="", ) + def transform_udf_chunks(self, text_chunks: TextChunks): + """Transform TextChunks into embeddings, preserving path information.""" + if not text_chunks or not text_chunks.chunks: + return [] + + model = self.get_model() + embeddings = model.encode( + text_chunks.chunks, + batch_size=self.getBatchSize(), + device=self.getSTDevice(), + ) + results = [] + for embedding in embeddings: + results.append( + EmbeddingsOutput( + path=text_chunks.path or "memory", + data=embedding.tolist(), + type="text_chunk", + exception=text_chunks.exception or "", + processing_time=text_chunks.processing_time or 0.0, + ), + ) + return results + @staticmethod def transform_udf_pandas(texts: pd.Series, params: pd.Series) -> pd.DataFrame: params = json.loads(params[0]) @@ -73,3 +186,36 @@ def transform_udf_pandas(texts: pd.Series, params: pd.Series) -> pd.DataFrame: ), ) return pd.DataFrame(results) + + @staticmethod + def transform_udf_pandas_chunks( + chunks_df: pd.DataFrame, + params: pd.Series, + ) -> pd.Series: + """Transform TextChunks into embeddings using pandas_udf, preserving path information.""" + params = json.loads(params.iloc[0]) + model = SentenceTransformer(params["model"]) + + results = [] + for _, row in chunks_df.iterrows(): + if len(row["chunks"]): + embeddings = model.encode( + row["chunks"], + batch_size=params["batchSize"], + device="cpu" if params["device"] == Device.CPU.value else "cuda", + ) + emb_results = [] + for embedding in embeddings: + emb_results.append( + asdict( + EmbeddingsOutput( + path=row.get("path") or "memory", + data=embedding.tolist(), + type="text_chunk", + exception=row.get("exception") or "", + processing_time=row.get("processing_time") or 0.0, + ), + ), + ) + results.append(emb_results) + return pd.Series(results) diff --git a/tests/conftest.py b/tests/conftest.py index be73c26..041abc9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,6 +8,7 @@ from scaledp.image.DataToImage import DataToImage from scaledp.pipeline.PandasPipeline import pathSparkFunctions, unpathSparkFunctions from scaledp.schemas.Image import Image +from scaledp.schemas.TextChunks import TextChunks @pytest.fixture @@ -237,7 +238,7 @@ def text_df(spark_session, resource_path_root): @pytest.fixture -def text_splitter_df(spark_session, resource_path_root): +def document_df(spark_session, resource_path_root): """Fixture for text splitter tests with Document struct column.""" from pyspark.sql.functions import col, lit, struct from pyspark.sql.types import ArrayType, StructType @@ -256,3 +257,24 @@ def text_splitter_df(spark_session, resource_path_root): lit("").alias("exception"), ), ).select("document") + + +@pytest.fixture +def df_text_chunks(spark_session): + """Fixture for TextChunks schema with sample data.""" + from pyspark.sql.functions import struct + + chunks_data = [ + ("file1.txt", ["hello world", "this is a test"], "", 1.0), + ("file2.txt", ["another chunk", "more text"], "", 2.0), + ("file3.txt", ["final chunk"], "", 0.5), + ] + + return spark_session.createDataFrame( + chunks_data, + schema=TextChunks.get_schema(), + ).select( + struct("path", "chunks", "exception", "processing_time").alias( + "text_chunks_col", + ), + ) diff --git a/tests/models/embeddings/test_text_embeddings.py b/tests/models/embeddings/test_text_embeddings.py index 27a4941..49318e4 100644 --- a/tests/models/embeddings/test_text_embeddings.py +++ b/tests/models/embeddings/test_text_embeddings.py @@ -77,3 +77,78 @@ def test_text_embeddings_pipeline_pandas(text_df): for row in data: assert row.embeddings.data is not None assert len(row.embeddings.data) > 0 + + +def test_text_embeddings_with_text_chunks(df_text_chunks): + """Test TextEmbeddings with TextChunks schema input.""" + # Initialize the TextEmbeddings stage with TextChunks input + text_embeddings = TextEmbeddings( + model="all-MiniLM-L6-v2", + inputCol="text_chunks_col", + outputCol="embeddings", + device=Device.CPU.value, + batchSize=2, + ) + + # Create a pipeline with the TextEmbeddings stage + pipeline = PipelineModel(stages=[text_embeddings]) + + result_df = pipeline.transform(df_text_chunks) + + # Cache the result for performance + result = result_df.select("embeddings", "text_chunks_col").cache() + + # Collect the results + data = result.collect() + + # Should have 5 rows (2 + 2 + 1 from exploded chunks) + assert len(data) == 5 + + # Check that exceptions are empty + assert all(row.embeddings.exception == "" for row in data) + + # Verify the embeddings are not empty and have correct type + for row in data: + assert row.embeddings.data is not None + assert len(row.embeddings.data) > 0 + assert row.embeddings.type == "text_chunk" + # Verify path is preserved from TextChunks + assert row.embeddings.path in ["file1.txt", "file2.txt", "file3.txt"] + + +def test_text_embeddings_with_text_chunks_pandas(df_text_chunks): + """Test TextEmbeddings with TextChunks schema input using pandas_udf.""" + # Initialize the TextEmbeddings stage with partitionMap=True (pandas_udf) + text_embeddings = TextEmbeddings( + model="all-MiniLM-L6-v2", + inputCol="text_chunks_col", + outputCol="embeddings", + device=Device.CPU.value, + partitionMap=True, + batchSize=2, + ) + + # Create a pipeline with the TextEmbeddings stage + pipeline = PipelineModel(stages=[text_embeddings]) + + result_df = pipeline.transform(df_text_chunks) + + # Cache the result for performance + result = result_df.select("embeddings", "text_chunks_col").cache() + + # Collect the results + data = result.collect() + + # Should have 5 rows (2 + 2 + 1 from pandas_udf processing) + assert len(data) == 5 + + # Check that exceptions are empty + assert all(row.embeddings.exception == "" for row in data) + + # Verify the embeddings are not empty and have correct type + for row in data: + assert row.embeddings.data is not None + assert len(row.embeddings.data) > 0 + assert row.embeddings.type == "text_chunk" + # Verify path is preserved from TextChunks + assert row.embeddings.path in ["file1.txt", "file2.txt", "file3.txt"] From aae881757da5355429da696a84a6b29cf1588854 Mon Sep 17 00:00:00 2001 From: mykola Date: Thu, 20 Nov 2025 18:05:38 +0300 Subject: [PATCH 3/3] doc: Updated TextEmbedding docs --- docs/source/embeddings.md | 74 ++++- .../models/embeddings/TextEmbeddings.md | 142 +++++++++- docs/source/schemas/embeddings_output.md | 252 ++++++++++++++++++ docs/source/schemas/index.md | 36 ++- 4 files changed, 487 insertions(+), 17 deletions(-) create mode 100644 docs/source/schemas/embeddings_output.md diff --git a/docs/source/embeddings.md b/docs/source/embeddings.md index 2034c29..c800113 100644 --- a/docs/source/embeddings.md +++ b/docs/source/embeddings.md @@ -5,10 +5,80 @@ Embeddings This section provides an overview of the various embedding transformers available in ScaleDP for processing text and other data types. These transformers are designed to generate embeddings that can be used for tasks such as clustering, classification, and semantic similarity. +Embeddings are generated from text data using neural language models and produce high-dimensional vectors that capture semantic meaning. ScaleDP embeddings support both raw text and structured text chunks as input, with automatic schema detection. + ## Text Embeddings -* [**TextEmbeddings**](models/embeddings/TextEmbeddings.md) +* [**TextEmbeddings**](models/embeddings/TextEmbeddings.md) - Generate embeddings from raw text or TextChunks ## Base Embeddings -* [**BaseEmbeddings**](models/embeddings/BaseEmbeddings.md) +* [**BaseEmbeddings**](models/embeddings/BaseEmbeddings.md) - Base class for embedding transformers + +## Related Schemas + +The embedding transformers work with the following data schemas: + +* [**TextChunks Schema**](schemas/text_chunks.md) - Input schema for chunk-based embeddings (from text splitters) +* [**EmbeddingsOutput Schema**](schemas/embeddings_output.md) - Output schema containing embedding vectors and metadata + +## Quick Start + +### Embedding Raw Text + +```python +from scaledp.models.embeddings.TextEmbeddings import TextEmbeddings + +embeddings = TextEmbeddings( + inputCol="text", + outputCol="embeddings", + model="all-MiniLM-L6-v2", + batchSize=32, +) + +result = embeddings.transform(text_df) +``` + +### Embedding Text Chunks + +```python +from scaledp.models.splitters.TextSplitter import TextSplitter +from scaledp.models.embeddings.TextEmbeddings import TextEmbeddings +from pyspark.ml import Pipeline + +# Create a pipeline: split text then generate embeddings +pipeline = Pipeline(stages=[ + TextSplitter(inputCol="document", outputCol="chunks", chunk_size=500), + TextEmbeddings(inputCol="chunks", outputCol="embeddings", model="all-MiniLM-L6-v2"), +]) + +result = pipeline.fit(documents_df).transform(documents_df) +``` + +## Pipeline Integration + +``` +Document Input + ↓ +Text Splitter (optional) + ↓ +TextChunks or Raw Text + ↓ +TextEmbeddings Transformer + ↓ +EmbeddingsOutput + ↓ +Downstream Applications (Search, Classification, Clustering, etc.) +``` + +## Output Format + +All embedding transformers produce output in the `EmbeddingsOutput` schema containing: + +- **path**: Source document path +- **data**: Embedding vector (list of floats) +- **type**: "text" (raw) or "text_chunk" (from chunks) +- **exception**: Error messages (if any) +- **processing_time**: Computation duration + +For detailed schema information, see [EmbeddingsOutput Schema Documentation](schemas/embeddings_output.md). diff --git a/docs/source/models/embeddings/TextEmbeddings.md b/docs/source/models/embeddings/TextEmbeddings.md index a168226..6075412 100644 --- a/docs/source/models/embeddings/TextEmbeddings.md +++ b/docs/source/models/embeddings/TextEmbeddings.md @@ -5,19 +5,25 @@ `TextEmbeddings` is a text embedding transformer based on the SentenceTransformer model. It is designed to efficiently generate embeddings for text data using a pre-trained model. The transformer is implemented as a PySpark ML transformer and can be integrated into Spark pipelines for scalable text embedding tasks. -## Usage Example +The transformer supports two input types: +- **Raw text**: Single text string per row +- **TextChunks schema**: Structured input with path, chunks (list of strings), exception, and processing_time fields + +## Usage Examples + +### Raw Text Input ```python from scaledp import TextEmbeddings, PipelineModel text_embeddings = TextEmbeddings( - inputCol="text", - outputCol="embeddings", - keepInputData=True, - model="all-MiniLM-L6-v2", - batchSize=1, - device="cpu", - ) + inputCol="text", + outputCol="embeddings", + keepInputData=True, + model="all-MiniLM-L6-v2", + batchSize=1, + device="cpu", +) # Transform the text dataframe through the embedding stage pipeline = PipelineModel(stages=[text_embeddings]) @@ -25,6 +31,31 @@ result = pipeline.transform(text_df) result.show() ``` +### TextChunks Schema Input + +```python +from scaledp import TextEmbeddings, PipelineModel +from scaledp.schemas.TextChunks import TextChunks + +# Assuming df has a column with TextChunks schema +text_embeddings = TextEmbeddings( + inputCol="text_chunks", + outputCol="embeddings", + keepInputData=True, + model="all-MiniLM-L6-v2", + batchSize=2, + device="cpu", +) + +# Transform the dataframe - automatically detects TextChunks schema +pipeline = PipelineModel(stages=[text_embeddings]) +result = pipeline.transform(df) + +# Each chunk in the TextChunks list generates one row with preserved metadata +# Embeddings include: path, data (embedding vector), type, exception, processing_time +result.show() +``` + ## Parameters | Parameter | Type | Description | Default | @@ -40,7 +71,102 @@ result.show() | pageCol | str | Page column | "page" | | pathCol | str | Path column | "path" | +## Input and Output Schemas + +### Input Types + +The transformer automatically detects and handles two input types: + +#### Raw Text +- **Type Detection**: StringType column +- **Behavior**: Each row generates one embedding +- **Output Type Field**: "text" + +#### TextChunks Schema +- **Type Detection**: StructType matching TextChunks schema +- **Behavior**: Each chunk generates one embedding row (using explode internally) +- **Output Type Field**: "text_chunk" +- **Metadata Preservation**: path, exception, processing_time preserved in output + +### Output Schema + +All outputs use the `EmbeddingsOutput` schema: + +| Field | Type | Description | +|------------------|---------------|------------------------------------------------------| +| path | Optional[str] | Source path (from TextChunks or "memory" for text) | +| data | Optional[list[float]] | Embedding vector | +| type | Optional[str] | "text" or "text_chunk" | +| exception | Optional[str] | Error message if generation failed | +| processing_time | Optional[float] | Processing duration in seconds | + +## Behavior + +### Raw Text Input +``` +Input: 1 row × 1 column + "This is a text sample" + +Transform → + +Output: 1 row × 1 column + EmbeddingsOutput( + path="memory", + data=[0.123, -0.456, ...], + type="text", + exception="", + processing_time=0.045 + ) +``` + +### TextChunks Input +``` +Input: 1 row × 1 column + TextChunks( + path="doc.txt", + chunks=["chunk 1", "chunk 2", "chunk 3"], + exception="", + processing_time=1.5 + ) + +Transform → + +Output: 3 rows × 1 column (one per chunk) + Row 1: EmbeddingsOutput(path="doc.txt", data=[...], type="text_chunk", ...) + Row 2: EmbeddingsOutput(path="doc.txt", data=[...], type="text_chunk", ...) + Row 3: EmbeddingsOutput(path="doc.txt", data=[...], type="text_chunk", ...) +``` + +## Schema Detection + +The transformer automatically detects input schema using `_is_text_chunks_column()`: +- Checks if column type is StructType +- Compares against TextChunks.get_schema() +- Falls back to raw text handling if not TextChunks + +## Performance Considerations + +### Batch Size +- Affects memory usage and inference speed +- Larger batches are faster but require more GPU memory +- Default: 1 (conservative) +- Recommended: 16-64 for GPUs, 2-8 for CPUs + +### Partition Map +- `partitionMap=False` (default): Uses regular UDF +- `partitionMap=True`: Uses pandas_udf for better performance +- Recommended: Enable for large datasets or when numPartitions > 1 + +### Parallelization +- Set `numPartitions` > 1 to distribute computation +- Pairs well with `partitionMap=True` for optimal performance +- Consider your cluster size and model size + ## Notes - The transformer uses the SentenceTransformer model for generating text embeddings. - Supports batch processing and distributed inference with Spark. +- Automatically detects input schema type (raw text vs TextChunks). +- When `partitionMap=True`, uses pandas_udf for better performance on large datasets. - Additional parameters can be set using the corresponding setter methods. +- See [EmbeddingsOutput Schema Documentation](../../schemas/embeddings_output.md) for detailed output schema information. +- See [TextChunks Schema Documentation](../../schemas/text_chunks.md) for detailed input schema information. diff --git a/docs/source/schemas/embeddings_output.md b/docs/source/schemas/embeddings_output.md new file mode 100644 index 0000000..8830435 --- /dev/null +++ b/docs/source/schemas/embeddings_output.md @@ -0,0 +1,252 @@ +(EmbeddingsOutput)= +# EmbeddingsOutput Schema + +## Overview + +`EmbeddingsOutput` is a structured data schema that represents the output of text embedding operations. It contains the generated embedding vector along with metadata about the source document, processing information, and error tracking. This schema is produced by embedding transformers and serves as the standard output format for all embedding generation tasks. + +The EmbeddingsOutput schema maintains traceability by preserving the source path and captures processing details for debugging, monitoring, and downstream analysis. + +## Schema Structure + +```python +from scaledp.schemas.EmbeddingsOutput import EmbeddingsOutput +from typing import Optional + +# EmbeddingsOutput dataclass definition +@dataclass(order=True) +class EmbeddingsOutput: + path: Optional[str] # Source document or chunk path + data: Optional[list[float]] # Embedding vector + type: Optional[str] # Type of input ("text" or "text_chunk") + exception: Optional[str] = "" # Error message if any + processing_time: Optional[float] = 0.0 # Processing duration in seconds +``` + +## Fields + +| Field | Type | Required | Description | +|------------------|-------------------|----------|------------------------------------------------| +| path | Optional[str] | No | Source document file path or "memory" | +| data | Optional[list[float]]| No | Embedding vector (list of floats) | +| type | Optional[str] | No | Input type: "text" or "text_chunk" | +| exception | Optional[str] | No | Error message if embedding generation failed | +| processing_time | Optional[float] | No | Time taken to generate embedding (in seconds) | + +## Usage Examples + +### Basic EmbeddingsOutput Creation + +```python +from scaledp.schemas.EmbeddingsOutput import EmbeddingsOutput + +# Successful embedding result +result = EmbeddingsOutput( + path="/documents/file.txt", + data=[0.123, -0.456, 0.789, ...], # Embedding vector + type="text", + exception="", + processing_time=0.045 +) + +# Embedding from text chunks +chunk_result = EmbeddingsOutput( + path="/documents/file.txt", + data=[0.234, -0.567, 0.890, ...], + type="text_chunk", + exception="", + processing_time=0.032 +) +``` + +## PySpark Schema + +```python +from scaledp.schemas.EmbeddingsOutput import EmbeddingsOutput + +schema = EmbeddingsOutput.get_schema() +# StructType([ +# StructField('path', StringType(), True), +# StructField('data', ArrayType(DoubleType()), True), +# StructField('type', StringType(), True), +# StructField('exception', StringType(), True), +# StructField('processing_time', DoubleType(), True) +# ]) +``` + +## Processing Pipeline Integration + +### With TextEmbeddings - Raw Text Input + +```python +from scaledp.models.embeddings.TextEmbeddings import TextEmbeddings + +# TextEmbeddings produces EmbeddingsOutput from raw text +text_embeddings = TextEmbeddings( + inputCol="text", + outputCol="embeddings", # Output is EmbeddingsOutput schema + model="all-MiniLM-L6-v2", +) + +result_df = text_embeddings.transform(input_df) +# result_df has column "embeddings" of type EmbeddingsOutput +# Each row has type="text" +``` + +### With TextEmbeddings - TextChunks Input + +```python +from scaledp.models.splitters.TextSplitter import TextSplitter +from scaledp.models.embeddings.TextEmbeddings import TextEmbeddings + +# Split text into chunks +splitter = TextSplitter( + inputCol="document", + outputCol="chunks", # Output is TextChunks + chunk_size=500, +) + +# Generate embeddings from chunks +embeddings = TextEmbeddings( + inputCol="chunks", + outputCol="embeddings", # Output is EmbeddingsOutput + model="all-MiniLM-L6-v2", +) + +# Create pipeline +from pyspark.ml import Pipeline +pipeline = Pipeline(stages=[splitter, embeddings]) +result_df = pipeline.fit(input_df).transform(input_df) + +# result_df has column "embeddings" of type EmbeddingsOutput +# Each chunk generates one embedding row with type="text_chunk" +# Path metadata is preserved from TextChunks +``` + +## Related Schemas + +- [TextChunks](./text_chunks.md) - Input schema for chunk-based embeddings +- [Document](./document.md) - Input schema for raw text embeddings + +## See Also + +- [TextEmbeddings Transformer](../models/embeddings/TextEmbeddings.md) +- [Text Splitters](../models/splitters/index.md) +- [Embeddings](../models/embeddings.md) + +## Best Practices + +### 1. Always Check for Errors +```python +# Good: check exception field +successful = df.filter(df.embeddings.exception == "") + +# Process only successful embeddings +vectors = successful.select("embeddings.data") +``` + +### 2. Preserve Path Information +```python +# Good: path is automatically preserved from TextChunks +# Use it for traceability +embeddings_with_path = result_df.select( + "embeddings.path", + "embeddings.data", + "embeddings.type" +) +``` + +### 3. Monitor Processing Time +```python +# Good: track embedding generation performance +stats = result_df.select("embeddings.processing_time").describe() +stats.show() +``` + +### 4. Use Appropriate Batch Sizes +```python +# Good: adjust batch size for your dataset and hardware +embeddings = TextEmbeddings( + inputCol="chunks", + outputCol="embeddings", + batchSize=32, # Tune based on GPU memory + model="all-MiniLM-L6-v2", +) +``` + +## Integration Examples + +### Full Text Processing Pipeline + +```python +from pyspark.ml import Pipeline +from scaledp.models.splitters.TextSplitter import TextSplitter +from scaledp.models.embeddings.TextEmbeddings import TextEmbeddings + +# Create complete pipeline +stages = [ + TextSplitter( + inputCol="document", + outputCol="chunks", + chunk_size=500, + ), + TextEmbeddings( + inputCol="chunks", + outputCol="embeddings", + model="all-MiniLM-L6-v2", + batchSize=32, + ), +] + +pipeline = Pipeline(stages=stages) +result_df = pipeline.fit(documents_df).transform(documents_df) + +# Access embeddings +result_df.select( + "embeddings.path", + "embeddings.data", + "embeddings.type", + "embeddings.processing_time" +).show() +``` + +### Batch Embedding Generation + +```python +from scaledp.models.embeddings.TextEmbeddings import TextEmbeddings + +# Configure for large-scale processing +embeddings = TextEmbeddings( + inputCol="text", + outputCol="embeddings", + model="all-MiniLM-L6-v2", + batchSize=64, + partitionMap=True, # Use pandas_udf for distributed processing + numPartitions=8, +) + +# Process large dataset +result_df = embeddings.transform(input_df) + +# Save embeddings for later use +result_df.write.mode("overwrite").parquet("s3://bucket/embeddings/") +``` + +### Embedding Storage and Retrieval + +```python +from pyspark.sql.functions import col + +# Save embeddings to vector database or storage +embeddings_storage = result_df.select( + col("embeddings.path").alias("document_id"), + col("embeddings.data").alias("vector"), + col("embeddings.type").alias("embedding_type"), +) + +embeddings_storage.write.format("parquet").mode("overwrite").save("embeddings/") + +# Load and retrieve +loaded = spark.read.parquet("embeddings/") +loaded.filter(col("embedding_type") == "text_chunk").show() +``` \ No newline at end of file diff --git a/docs/source/schemas/index.md b/docs/source/schemas/index.md index cebe89c..18e1e56 100644 --- a/docs/source/schemas/index.md +++ b/docs/source/schemas/index.md @@ -28,6 +28,18 @@ Represents the output of text splitting operations. **Usage:** Output from text splitters, input for embeddings. +### [EmbeddingsOutput](./embeddings_output.md) +Represents the output of text embedding operations. + +**Key Fields:** +- `path` - Source document or chunk path +- `data` - Embedding vector (list of floats) +- `type` - Input type ("text" or "text_chunk") +- `exception` - Error message if any +- `processing_time` - Processing duration + +**Usage:** Output from embedding transformers, input for similarity search and vector databases. + ### [Box](./box.md) Represents a bounding box with position and text information. @@ -62,6 +74,13 @@ TextChunks ├── chunks: list[str] ├── exception: str └── processing_time: float + +EmbeddingsOutput +├── path: str +├── data: list[float] +├── type: str +├── exception: str +└── processing_time: float ``` ## Processing Pipeline Overview @@ -73,18 +92,21 @@ Text Splitter ↓ TextChunks (Split Results) ↓ -Embeddings / NER / Other Processing +TextEmbeddings + ↓ +EmbeddingsOutput (Embeddings) ↓ -Results +Vector DB / Similarity Search / Other Processing ``` ## Quick Reference -| Schema | Purpose | Input To | Output From | -|-------------|------------------------------|-------------------|-----------------| -| Document | Represents document data | Text Splitter, NER | PDF Reader | -| TextChunks | Represents text chunks | Embeddings, Search | Text Splitter | -| Box | Represents spatial region | Layout Analysis | OCR, Detector | +| Schema | Purpose | Input To | Output From | +|------------------|------------------------------|-------------------|-------------------| +| Document | Represents document data | Text Splitter, NER | PDF Reader | +| TextChunks | Represents text chunks | Embeddings, Search | Text Splitter | +| EmbeddingsOutput | Represents embeddings | Vector DB, Search | TextEmbeddings | +| Box | Represents spatial region | Layout Analysis | OCR, Detector | ## Common Operations