74 changes: 72 additions & 2 deletions docs/source/embeddings.md
# Embeddings

This section provides an overview of the various embedding transformers available in ScaleDP for processing text and other data types. These transformers are designed to generate embeddings that can be used for tasks such as clustering, classification, and semantic similarity.

Embeddings are generated from text data using neural language models and produce high-dimensional vectors that capture semantic meaning. ScaleDP embeddings support both raw text and structured text chunks as input, with automatic schema detection.

## Text Embeddings

* [**TextEmbeddings**](models/embeddings/TextEmbeddings.md) - Generate embeddings from raw text or TextChunks

## Base Embeddings

* [**BaseEmbeddings**](models/embeddings/BaseEmbeddings.md) - Base class for embedding transformers

## Related Schemas

The embedding transformers work with the following data schemas:

* [**TextChunks Schema**](schemas/text_chunks.md) - Input schema for chunk-based embeddings (from text splitters)
* [**EmbeddingsOutput Schema**](schemas/embeddings_output.md) - Output schema containing embedding vectors and metadata

## Quick Start

### Embedding Raw Text

```python
from scaledp.models.embeddings.TextEmbeddings import TextEmbeddings

embeddings = TextEmbeddings(
    inputCol="text",
    outputCol="embeddings",
    model="all-MiniLM-L6-v2",
    batchSize=32,
)

result = embeddings.transform(text_df)
```

### Embedding Text Chunks

```python
from scaledp.models.splitters.TextSplitter import TextSplitter
from scaledp.models.embeddings.TextEmbeddings import TextEmbeddings
from pyspark.ml import Pipeline

# Create a pipeline: split text then generate embeddings
pipeline = Pipeline(stages=[
    TextSplitter(inputCol="document", outputCol="chunks", chunk_size=500),
    TextEmbeddings(inputCol="chunks", outputCol="embeddings", model="all-MiniLM-L6-v2"),
])

result = pipeline.fit(documents_df).transform(documents_df)
```

## Pipeline Integration

```
Document Input
      ↓
Text Splitter (optional)
      ↓
TextChunks or Raw Text
      ↓
TextEmbeddings Transformer
      ↓
EmbeddingsOutput
      ↓
Downstream Applications (Search, Classification, Clustering, etc.)
```

## Output Format

All embedding transformers produce output in the `EmbeddingsOutput` schema containing:

- **path**: Source document path
- **data**: Embedding vector (list of floats)
- **type**: "text" (raw) or "text_chunk" (from chunks)
- **exception**: Error messages (if any)
- **processing_time**: Computation duration

For detailed schema information, see [EmbeddingsOutput Schema Documentation](schemas/embeddings_output.md).
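As an illustration of the field list above, an output row can be modeled as a plain Python record. This is an illustrative stand-in built from the documented fields, not the library's actual `EmbeddingsOutput` class:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class EmbeddingsOutputRow:
    """Illustrative model of the EmbeddingsOutput fields described above."""
    path: Optional[str]               # source document path ("memory" for raw text)
    data: Optional[list]              # embedding vector as a list of floats
    type: Optional[str]               # "text" or "text_chunk"
    exception: Optional[str]          # error message, empty on success
    processing_time: Optional[float]  # computation duration in seconds

# A row as it might look after embedding a raw-text input
row = EmbeddingsOutputRow(
    path="memory",
    data=[0.12, -0.45, 0.08],
    type="text",
    exception="",
    processing_time=0.04,
)
assert row.type in ("text", "text_chunk")
```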
142 changes: 134 additions & 8 deletions docs/source/models/embeddings/TextEmbeddings.md

`TextEmbeddings` is a text embedding transformer based on the SentenceTransformer model. It is designed to efficiently generate embeddings for text data using a pre-trained model. The transformer is implemented as a PySpark ML transformer and can be integrated into Spark pipelines for scalable text embedding tasks.

## Usage Example
The transformer supports two input types:
- **Raw text**: Single text string per row
- **TextChunks schema**: Structured input with path, chunks (list of strings), exception, and processing_time fields

## Usage Examples

### Raw Text Input

```python
from scaledp import TextEmbeddings, PipelineModel

text_embeddings = TextEmbeddings(
    inputCol="text",
    outputCol="embeddings",
    keepInputData=True,
    model="all-MiniLM-L6-v2",
    batchSize=1,
    device="cpu",
)

# Transform the text dataframe through the embedding stage
pipeline = PipelineModel(stages=[text_embeddings])
result = pipeline.transform(text_df)
result.show()
```

### TextChunks Schema Input

```python
from scaledp import TextEmbeddings, PipelineModel
from scaledp.schemas.TextChunks import TextChunks

# Assuming df has a column with TextChunks schema
text_embeddings = TextEmbeddings(
    inputCol="text_chunks",
    outputCol="embeddings",
    keepInputData=True,
    model="all-MiniLM-L6-v2",
    batchSize=2,
    device="cpu",
)

# Transform the dataframe - automatically detects TextChunks schema
pipeline = PipelineModel(stages=[text_embeddings])
result = pipeline.transform(df)

# Each chunk in the TextChunks list generates one row with preserved metadata
# Embeddings include: path, data (embedding vector), type, exception, processing_time
result.show()
```

## Parameters

| Parameter | Type | Description | Default |
|-----------|------|-------------|---------|
| pageCol   | str  | Page column | "page"  |
| pathCol   | str  | Path column | "path"  |

## Input and Output Schemas

### Input Types

The transformer automatically detects and handles two input types:

#### Raw Text
- **Type Detection**: StringType column
- **Behavior**: Each row generates one embedding
- **Output Type Field**: "text"

#### TextChunks Schema
- **Type Detection**: StructType matching TextChunks schema
- **Behavior**: Each chunk generates one embedding row (using explode internally)
- **Output Type Field**: "text_chunk"
- **Metadata Preservation**: path, exception, processing_time preserved in output

### Output Schema

All outputs use the `EmbeddingsOutput` schema:

| Field | Type | Description |
|------------------|---------------|------------------------------------------------------|
| path | Optional[str] | Source path (from TextChunks or "memory" for text) |
| data | Optional[list[float]] | Embedding vector |
| type | Optional[str] | "text" or "text_chunk" |
| exception | Optional[str] | Error message if generation failed |
| processing_time | Optional[float] | Processing duration in seconds |

## Behavior

### Raw Text Input
```
Input: 1 row × 1 column
  "This is a text sample"

Transform →

Output: 1 row × 1 column
  EmbeddingsOutput(
    path="memory",
    data=[0.123, -0.456, ...],
    type="text",
    exception="",
    processing_time=0.045
  )
```

### TextChunks Input
```
Input: 1 row × 1 column
  TextChunks(
    path="doc.txt",
    chunks=["chunk 1", "chunk 2", "chunk 3"],
    exception="",
    processing_time=1.5
  )

Transform →

Output: 3 rows × 1 column (one per chunk)
Row 1: EmbeddingsOutput(path="doc.txt", data=[...], type="text_chunk", ...)
Row 2: EmbeddingsOutput(path="doc.txt", data=[...], type="text_chunk", ...)
Row 3: EmbeddingsOutput(path="doc.txt", data=[...], type="text_chunk", ...)
```
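The chunk-level fan-out above can be sketched in plain Python. This is a model of the documented behavior, not the transformer's actual implementation, and `fake_embed` is a hypothetical stand-in for the SentenceTransformer model call:

```python
def fake_embed(text):
    # Hypothetical placeholder for the real embedding model; returns a toy vector.
    return [float(len(text)), 0.0]

def explode_chunks(text_chunks):
    """Model of the TextChunks fan-out: one output row per chunk,
    with path/exception/processing_time metadata carried over."""
    rows = []
    for chunk in text_chunks["chunks"]:
        rows.append({
            "path": text_chunks["path"],
            "data": fake_embed(chunk),
            "type": "text_chunk",
            "exception": text_chunks["exception"],
            "processing_time": text_chunks["processing_time"],
        })
    return rows

rows = explode_chunks({
    "path": "doc.txt",
    "chunks": ["chunk 1", "chunk 2", "chunk 3"],
    "exception": "",
    "processing_time": 1.5,
})
assert len(rows) == 3
```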

## Schema Detection

The transformer automatically detects input schema using `_is_text_chunks_column()`:
- Checks if column type is StructType
- Compares against TextChunks.get_schema()
- Falls back to raw text handling if not TextChunks
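The detection logic can be approximated without Spark as a field-name comparison. This is a plain-Python analogue of the check described above, under the assumption that the TextChunks struct carries exactly the four documented fields; it is not the library's `_is_text_chunks_column()`:

```python
# Assumed field set for the TextChunks struct, taken from the documented
# schema above -- not read from the library source.
TEXT_CHUNKS_FIELDS = {"path", "chunks", "exception", "processing_time"}

def is_text_chunks_column(column_fields):
    """Treat a column as TextChunks only when it is a struct (has fields)
    and its field names match the expected schema; otherwise fall back
    to raw-text handling."""
    if column_fields is None:  # plain StringType column has no struct fields
        return False
    return set(column_fields) == TEXT_CHUNKS_FIELDS

assert is_text_chunks_column(["path", "chunks", "exception", "processing_time"])
assert not is_text_chunks_column(None)          # raw text column
assert not is_text_chunks_column(["path"])      # some other struct
```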

## Performance Considerations

### Batch Size
- Affects memory usage and inference speed
- Larger batches are faster but require more GPU memory
- Default: 1 (conservative)
- Recommended: 16-64 for GPUs, 2-8 for CPUs
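The grouping that `batchSize` controls is simple list slicing: rows are collected into groups of at most `batchSize` before each model call. A sketch of that grouping, not the transformer's code:

```python
def make_batches(texts, batch_size):
    """Group rows into batches of at most batch_size for model inference.
    The final batch may be smaller than batch_size."""
    return [texts[i:i + batch_size] for i in range(0, len(texts), batch_size)]

batches = make_batches([f"row {i}" for i in range(10)], batch_size=4)
assert [len(b) for b in batches] == [4, 4, 2]
```

Larger batches mean fewer model calls per partition, which is where the speed/memory trade-off above comes from.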

### Partition Map
- `partitionMap=False` (default): Uses regular UDF
- `partitionMap=True`: Uses pandas_udf for better performance
- Recommended: Enable for large datasets or when numPartitions > 1

### Parallelization
- Set `numPartitions` > 1 to distribute computation
- Pairs well with `partitionMap=True` for optimal performance
- Consider your cluster size and model size

## Notes
- The transformer uses the SentenceTransformer model for generating text embeddings.
- Supports batch processing and distributed inference with Spark.
- Automatically detects input schema type (raw text vs TextChunks).
- When `partitionMap=True`, uses pandas_udf for better performance on large datasets.
- Additional parameters can be set using the corresponding setter methods.
- See [EmbeddingsOutput Schema Documentation](../../schemas/embeddings_output.md) for detailed output schema information.
- See [TextChunks Schema Documentation](../../schemas/text_chunks.md) for detailed input schema information.