From ed46b033f1df558b719011139f6921f7078ee21c Mon Sep 17 00:00:00 2001 From: mykola Date: Wed, 12 Nov 2025 08:29:23 +0300 Subject: [PATCH 1/3] feat: Added TextEmbeddings transformer, for compute embedding using SentenceTransformers --- docs/source/embeddings.md | 14 +++ docs/source/index.rst | 1 + .../models/embeddings/BaseEmbeddings.md | 39 ++++++ .../models/embeddings/TextEmbeddings.md | 46 ++++++++ poetry.lock | 111 +++++++++++++++++- pyproject.toml | 11 +- scaledp/models/embeddings/BaseEmbeddings.py | 75 ++++++++++++ scaledp/models/embeddings/TextEmbeddings.py | 75 ++++++++++++ scaledp/models/embeddings/__init__.py | 0 scaledp/params.py | 9 ++ scaledp/schemas/EmbeddingsOutput.py | 21 ++++ scaledp/utils/dataclass.py | 7 ++ tests/models/embeddings/__init__.py | 0 .../models/embeddings/test_text_embeddings.py | 79 +++++++++++++ 14 files changed, 482 insertions(+), 6 deletions(-) create mode 100644 docs/source/embeddings.md create mode 100644 docs/source/models/embeddings/BaseEmbeddings.md create mode 100644 docs/source/models/embeddings/TextEmbeddings.md create mode 100644 scaledp/models/embeddings/BaseEmbeddings.py create mode 100644 scaledp/models/embeddings/TextEmbeddings.py create mode 100644 scaledp/models/embeddings/__init__.py create mode 100644 scaledp/schemas/EmbeddingsOutput.py create mode 100644 tests/models/embeddings/__init__.py create mode 100644 tests/models/embeddings/test_text_embeddings.py diff --git a/docs/source/embeddings.md b/docs/source/embeddings.md new file mode 100644 index 0000000..2034c29 --- /dev/null +++ b/docs/source/embeddings.md @@ -0,0 +1,14 @@ +Embeddings +========== + +## Overview + +This section provides an overview of the various embedding transformers available in ScaleDP for processing text and other data types. These transformers are designed to generate embeddings that can be used for tasks such as clustering, classification, and semantic similarity. 
+ +## Text Embeddings + +* [**TextEmbeddings**](models/embeddings/TextEmbeddings.md) + +## Base Embeddings + +* [**BaseEmbeddings**](models/embeddings/BaseEmbeddings.md) diff --git a/docs/source/index.rst b/docs/source/index.rst index 4c39230..df628ed 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -46,6 +46,7 @@ Benefits of using ScaleDP pdf_processing.md detectors.md ocr.md + embeddings.md show_utils.md release_notes.md diff --git a/docs/source/models/embeddings/BaseEmbeddings.md b/docs/source/models/embeddings/BaseEmbeddings.md new file mode 100644 index 0000000..efc3048 --- /dev/null +++ b/docs/source/models/embeddings/BaseEmbeddings.md @@ -0,0 +1,39 @@ +(BaseEmbeddings)= +# BaseEmbeddings + +## Overview + +`BaseEmbeddings` is an abstract base class for embedding transformers in ScaleDP. It provides the foundational structure and common functionality for embedding models, enabling efficient and scalable embedding generation for various data types. Derived classes, such as `TextEmbeddings`, extend this base class to implement specific embedding logic. + +## Key Features + +- **Abstract Base Class**: Provides a common interface for embedding transformers. +- **PySpark Integration**: Designed to work seamlessly with PySpark for distributed data processing. +- **Customizable Parameters**: Supports a wide range of parameters for flexibility and customization. +- **Input Validation**: Raises a `ValueError` when the configured input column is missing from the dataset. + +## Usage Example + +`BaseEmbeddings` is not intended to be used directly. Instead, it serves as a parent class for specific embedding transformers like `TextEmbeddings`. 
+ +## Parameters + +| Parameter | Type | Description | Default | +|-------------------|---------|--------------------------------------------------|-----------------------------| +| inputCol | str | Input column name | N/A | +| outputCol | str | Output column name | N/A | +| keepInputData | bool | Whether to retain input data in the output | True | +| device | Device | Device for computation (CPU/GPU) | Device.CPU | +| model | str | Pre-trained model identifier | N/A | +| batchSize | int | Batch size for processing | 1 | +| numPartitions | int | Number of partitions for distributed processing | 1 | +| partitionMap | bool | Use partitioned mapping | False | +| pageCol | str | Page column | "page" | +| pathCol | str | Path column | "path" | + +## Notes + +- `BaseEmbeddings` provides the `_transform` method, which handles the core logic for applying transformations to a dataset. +- Derived classes must implement the `transform_udf` and `transform_udf_pandas` methods to define the specific embedding logic. +- The class includes validation for input columns to ensure compatibility with the dataset. + diff --git a/docs/source/models/embeddings/TextEmbeddings.md b/docs/source/models/embeddings/TextEmbeddings.md new file mode 100644 index 0000000..a168226 --- /dev/null +++ b/docs/source/models/embeddings/TextEmbeddings.md @@ -0,0 +1,46 @@ +(TextEmbeddings)= +# TextEmbeddings + +## Overview + +`TextEmbeddings` is a text embedding transformer based on the SentenceTransformer model. It is designed to efficiently generate embeddings for text data using a pre-trained model. The transformer is implemented as a PySpark ML transformer and can be integrated into Spark pipelines for scalable text embedding tasks. 
+ +## Usage Example + +```python +from scaledp import TextEmbeddings, PipelineModel + +text_embeddings = TextEmbeddings( + inputCol="text", + outputCol="embeddings", + keepInputData=True, + model="all-MiniLM-L6-v2", + batchSize=1, + device="cpu", + ) + +# Transform the text dataframe through the embedding stage +pipeline = PipelineModel(stages=[text_embeddings]) +result = pipeline.transform(text_df) +result.show() +``` + +## Parameters + +| Parameter | Type | Description | Default | +|-------------------|---------|--------------------------------------------------|-----------------------------| +| inputCol | str | Input text column | "text" | +| outputCol | str | Output column for embeddings | "embeddings" | +| keepInputData | bool | Keep input data in output | True | +| model | str | Pre-trained model identifier | "all-MiniLM-L6-v2" | +| batchSize | int | Batch size for inference | 1 | +| device | Device | Inference device (CPU/GPU) | Device.CPU | +| numPartitions | int | Number of partitions | 1 | +| partitionMap | bool | Use partitioned mapping | False | +| pageCol | str | Page column | "page" | +| pathCol | str | Path column | "path" | + +## Notes +- The transformer uses the SentenceTransformer model for generating text embeddings. +- Supports batch processing and distributed inference with Spark. +- Additional parameters can be set using the corresponding setter methods. 
diff --git a/poetry.lock b/poetry.lock index 74da421..7b93403 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1687,6 +1687,18 @@ files = [ {file = "jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe"}, ] +[[package]] +name = "joblib" +version = "1.5.2" +description = "Lightweight pipelining with Python functions" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "joblib-1.5.2-py3-none-any.whl", hash = "sha256:4e1f0bdbb987e6d843c70cf43714cb276623def372df3c22fe5266b2670bc241"}, + {file = "joblib-1.5.2.tar.gz", hash = "sha256:3faa5c39054b2f03ca547da9b2f52fde67c06240c31853f306aea97f13647b55"}, +] + [[package]] name = "json5" version = "0.12.0" @@ -5162,6 +5174,62 @@ docs = ["PyWavelets (>=1.6)", "dask[array] (>=2023.2.0)", "intersphinx-registry optional = ["PyWavelets (>=1.6)", "SimpleITK", "astropy (>=5.0)", "cloudpickle (>=1.1.1)", "dask[array] (>=2023.2.0)", "matplotlib (>=3.7)", "pooch (>=1.6.0)", "pyamg (>=5.2)", "scikit-learn (>=1.2)"] test = ["asv", "numpydoc (>=1.7)", "pooch (>=1.6.0)", "pytest (>=8)", "pytest-cov (>=2.11.0)", "pytest-doctestplus", "pytest-faulthandler", "pytest-localserver"] +[[package]] +name = "scikit-learn" +version = "1.7.2" +description = "A set of python modules for machine learning and data mining" +optional = false +python-versions = ">=3.10" +groups = ["main"] +files = [ + {file = "scikit_learn-1.7.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6b33579c10a3081d076ab403df4a4190da4f4432d443521674637677dc91e61f"}, + {file = "scikit_learn-1.7.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:36749fb62b3d961b1ce4fedf08fa57a1986cd409eff2d783bca5d4b9b5fce51c"}, + {file = "scikit_learn-1.7.2-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:7a58814265dfc52b3295b1900cfb5701589d30a8bb026c7540f1e9d3499d5ec8"}, + {file = "scikit_learn-1.7.2-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = 
"sha256:4a847fea807e278f821a0406ca01e387f97653e284ecbd9750e3ee7c90347f18"}, + {file = "scikit_learn-1.7.2-cp310-cp310-win_amd64.whl", hash = "sha256:ca250e6836d10e6f402436d6463d6c0e4d8e0234cfb6a9a47835bd392b852ce5"}, + {file = "scikit_learn-1.7.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c7509693451651cd7361d30ce4e86a1347493554f172b1c72a39300fa2aea79e"}, + {file = "scikit_learn-1.7.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:0486c8f827c2e7b64837c731c8feff72c0bd2b998067a8a9cbc10643c31f0fe1"}, + {file = "scikit_learn-1.7.2-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:89877e19a80c7b11a2891a27c21c4894fb18e2c2e077815bcade10d34287b20d"}, + {file = "scikit_learn-1.7.2-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8da8bf89d4d79aaec192d2bda62f9b56ae4e5b4ef93b6a56b5de4977e375c1f1"}, + {file = "scikit_learn-1.7.2-cp311-cp311-win_amd64.whl", hash = "sha256:9b7ed8d58725030568523e937c43e56bc01cadb478fc43c042a9aca1dacb3ba1"}, + {file = "scikit_learn-1.7.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:8d91a97fa2b706943822398ab943cde71858a50245e31bc71dba62aab1d60a96"}, + {file = "scikit_learn-1.7.2-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:acbc0f5fd2edd3432a22c69bed78e837c70cf896cd7993d71d51ba6708507476"}, + {file = "scikit_learn-1.7.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e5bf3d930aee75a65478df91ac1225ff89cd28e9ac7bd1196853a9229b6adb0b"}, + {file = "scikit_learn-1.7.2-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b4d6e9deed1a47aca9fe2f267ab8e8fe82ee20b4526b2c0cd9e135cea10feb44"}, + {file = "scikit_learn-1.7.2-cp312-cp312-win_amd64.whl", hash = "sha256:6088aa475f0785e01bcf8529f55280a3d7d298679f50c0bb70a2364a82d0b290"}, + {file = "scikit_learn-1.7.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:0b7dacaa05e5d76759fb071558a8b5130f4845166d88654a0f9bdf3eb57851b7"}, + {file = 
"scikit_learn-1.7.2-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:abebbd61ad9e1deed54cca45caea8ad5f79e1b93173dece40bb8e0c658dbe6fe"}, + {file = "scikit_learn-1.7.2-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:502c18e39849c0ea1a5d681af1dbcf15f6cce601aebb657aabbfe84133c1907f"}, + {file = "scikit_learn-1.7.2-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7a4c328a71785382fe3fe676a9ecf2c86189249beff90bf85e22bdb7efaf9ae0"}, + {file = "scikit_learn-1.7.2-cp313-cp313-win_amd64.whl", hash = "sha256:63a9afd6f7b229aad94618c01c252ce9e6fa97918c5ca19c9a17a087d819440c"}, + {file = "scikit_learn-1.7.2-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:9acb6c5e867447b4e1390930e3944a005e2cb115922e693c08a323421a6966e8"}, + {file = "scikit_learn-1.7.2-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:2a41e2a0ef45063e654152ec9d8bcfc39f7afce35b08902bfe290c2498a67a6a"}, + {file = "scikit_learn-1.7.2-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:98335fb98509b73385b3ab2bd0639b1f610541d3988ee675c670371d6a87aa7c"}, + {file = "scikit_learn-1.7.2-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:191e5550980d45449126e23ed1d5e9e24b2c68329ee1f691a3987476e115e09c"}, + {file = "scikit_learn-1.7.2-cp313-cp313t-win_amd64.whl", hash = "sha256:57dc4deb1d3762c75d685507fbd0bc17160144b2f2ba4ccea5dc285ab0d0e973"}, + {file = "scikit_learn-1.7.2-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:fa8f63940e29c82d1e67a45d5297bdebbcb585f5a5a50c4914cc2e852ab77f33"}, + {file = "scikit_learn-1.7.2-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:f95dc55b7902b91331fa4e5845dd5bde0580c9cd9612b1b2791b7e80c3d32615"}, + {file = "scikit_learn-1.7.2-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9656e4a53e54578ad10a434dc1f993330568cfee176dff07112b8785fb413106"}, + {file = "scikit_learn-1.7.2-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", 
hash = "sha256:96dc05a854add0e50d3f47a1ef21a10a595016da5b007c7d9cd9d0bffd1fcc61"}, + {file = "scikit_learn-1.7.2-cp314-cp314-win_amd64.whl", hash = "sha256:bb24510ed3f9f61476181e4db51ce801e2ba37541def12dc9333b946fc7a9cf8"}, + {file = "scikit_learn-1.7.2.tar.gz", hash = "sha256:20e9e49ecd130598f1ca38a1d85090e1a600147b9c02fa6f15d69cb53d968fda"}, +] + +[package.dependencies] +joblib = ">=1.2.0" +numpy = ">=1.22.0" +scipy = ">=1.8.0" +threadpoolctl = ">=3.1.0" + +[package.extras] +benchmark = ["matplotlib (>=3.5.0)", "memory_profiler (>=0.57.0)", "pandas (>=1.4.0)"] +build = ["cython (>=3.0.10)", "meson-python (>=0.17.1)", "numpy (>=1.22.0)", "scipy (>=1.8.0)"] +docs = ["Pillow (>=8.4.0)", "matplotlib (>=3.5.0)", "memory_profiler (>=0.57.0)", "numpydoc (>=1.2.0)", "pandas (>=1.4.0)", "plotly (>=5.14.0)", "polars (>=0.20.30)", "pooch (>=1.6.0)", "pydata-sphinx-theme (>=0.15.3)", "scikit-image (>=0.19.0)", "seaborn (>=0.9.0)", "sphinx (>=7.3.7)", "sphinx-copybutton (>=0.5.2)", "sphinx-design (>=0.5.0)", "sphinx-design (>=0.6.0)", "sphinx-gallery (>=0.17.1)", "sphinx-prompt (>=1.4.0)", "sphinx-remove-toctrees (>=1.0.0.post1)", "sphinxcontrib-sass (>=0.3.4)", "sphinxext-opengraph (>=0.9.1)", "towncrier (>=24.8.0)"] +examples = ["matplotlib (>=3.5.0)", "pandas (>=1.4.0)", "plotly (>=5.14.0)", "pooch (>=1.6.0)", "scikit-image (>=0.19.0)", "seaborn (>=0.9.0)"] +install = ["joblib (>=1.2.0)", "numpy (>=1.22.0)", "scipy (>=1.8.0)", "threadpoolctl (>=3.1.0)"] +maintenance = ["conda-lock (==3.0.1)"] +tests = ["matplotlib (>=3.5.0)", "mypy (>=1.15)", "numpydoc (>=1.2.0)", "pandas (>=1.4.0)", "polars (>=0.20.30)", "pooch (>=1.6.0)", "pyamg (>=4.2.1)", "pyarrow (>=12.0.0)", "pytest (>=7.1.2)", "pytest-cov (>=2.9.0)", "ruff (>=0.11.7)", "scikit-image (>=0.19.0)"] + [[package]] name = "scipy" version = "1.15.2" @@ -5243,6 +5311,35 @@ nativelib = ["pyobjc-framework-Cocoa ; sys_platform == \"darwin\"", "pywin32 ; s objc = ["pyobjc-framework-Cocoa ; sys_platform == \"darwin\""] win32 = 
["pywin32 ; sys_platform == \"win32\""] +[[package]] +name = "sentence-transformers" +version = "5.1.2" +description = "Embeddings, Retrieval, and Reranking" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "sentence_transformers-5.1.2-py3-none-any.whl", hash = "sha256:724ce0ea62200f413f1a5059712aff66495bc4e815a1493f7f9bca242414c333"}, + {file = "sentence_transformers-5.1.2.tar.gz", hash = "sha256:0f6c8bd916a78dc65b366feb8d22fd885efdb37432e7630020d113233af2b856"}, +] + +[package.dependencies] +huggingface-hub = ">=0.20.0" +Pillow = "*" +scikit-learn = "*" +scipy = "*" +torch = ">=1.11.0" +tqdm = "*" +transformers = ">=4.41.0,<5.0.0" +typing_extensions = ">=4.5.0" + +[package.extras] +dev = ["accelerate (>=0.20.3)", "datasets", "peft", "pre-commit", "pytest", "pytest-cov"] +onnx = ["optimum[onnxruntime] (>=1.23.1)"] +onnx-gpu = ["optimum[onnxruntime-gpu] (>=1.23.1)"] +openvino = ["optimum-intel[openvino] (>=1.20.0)"] +train = ["accelerate (>=0.20.3)", "datasets"] + [[package]] name = "setuptools" version = "78.1.0" @@ -5533,6 +5630,18 @@ files = [ {file = "tesserocr-2.7.1.tar.gz", hash = "sha256:3744c5c8bbabf18172849c7731be00dc2e5e44f8c556d37c850e788794ae0af4"}, ] +[[package]] +name = "threadpoolctl" +version = "3.6.0" +description = "threadpoolctl" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "threadpoolctl-3.6.0-py3-none-any.whl", hash = "sha256:43a0b8fd5a2928500110039e43a5eed8480b918967083ea48dc3ab9f13c4a7fb"}, + {file = "threadpoolctl-3.6.0.tar.gz", hash = "sha256:8ab8b4aa3491d812b623328249fab5302a68d2d71745c8a4c719a2fcaba9f44e"}, +] + [[package]] name = "tifffile" version = "2025.3.30" @@ -6305,4 +6414,4 @@ paddle = [] [metadata] lock-version = "2.1" python-versions = "^3.10" -content-hash = "5e3ef145fd8c68bd9c68b6889b96b34e848ccf413861717e9a563dc2de0fe552" +content-hash = "44e11ba343de454ae8d136f03cfdaca2869e1587775106317485ec2f18ec0755" diff --git a/pyproject.toml b/pyproject.toml 
index 15bfb29..c00b631 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "scaledp" -version = "0.2.5" +version = "0.3.0rc1" description = "ScaleDP is a library for processing documents and images using Apache Spark and LLMs" authors = ["Mykola Melnyk "] repository = "https://github.com/StabRise/scaledp" @@ -46,14 +46,14 @@ shapely = "^2.1.1" pyclipper = "^1.3.0.post6" onnxruntime = "1.22.0" opencv-python = "^4.12.0.88" - - +sentence-transformers = {version ="^5.1.2", optional = true} [tool.poetry.extras] ml = ["transformers", - #"torch", - #"torchvision" + "torch", + "torchvision", + "sentence-transformers" ] ocr = ["easyocr", "python-doctr", "surya-ocr"] llm = ["dspy"] @@ -88,6 +88,7 @@ ultralytics = "^8.3.40" pre-commit = "^3.7.1" ruff = "^0.5.0" craft-text-detector-updated = "^0.4.7" +sentence-transformers = "^5.1.2" [tool.poetry.group.dev.dependencies] diff --git a/scaledp/models/embeddings/BaseEmbeddings.py b/scaledp/models/embeddings/BaseEmbeddings.py new file mode 100644 index 0000000..49e8ea9 --- /dev/null +++ b/scaledp/models/embeddings/BaseEmbeddings.py @@ -0,0 +1,75 @@ +import json + +from pyspark.ml import Transformer +from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable +from pyspark.sql.functions import lit, pandas_udf, udf + +from scaledp.params import ( + HasBatchSize, + HasColumnValidator, + HasDefaultEnum, + HasDevice, + HasInputCol, + HasKeepInputData, + HasModel, + HasNumPartitions, + HasOutputCol, + HasPageCol, + HasPartitionMap, + HasPathCol, +) +from scaledp.schemas.EmbeddingsOutput import EmbeddingsOutput + + +class BaseEmbeddings( + Transformer, + HasInputCol, + HasOutputCol, + HasKeepInputData, + HasDevice, + HasModel, + HasPathCol, + DefaultParamsReadable, + DefaultParamsWritable, + HasNumPartitions, + HasBatchSize, + HasPageCol, + HasColumnValidator, + HasDefaultEnum, + HasPartitionMap, +): + + def get_params(self): + return json.dumps({k.name: v for k, v in 
self.extractParamMap().items()}) + + def _transform(self, dataset): + params = self.get_params() + out_col = self.getOutputCol() + input_col = self.getInputCol() + if input_col not in dataset.columns: + raise ValueError(f"Column {input_col} not found in dataset") + in_col = self._validate(input_col, dataset) + + if not self.getPartitionMap(): + result = dataset.withColumn( + out_col, + udf(self.transform_udf, EmbeddingsOutput.get_schema())(in_col), + ) + else: + if self.getNumPartitions() > 0: + if self.getPageCol() in dataset.columns: + dataset = dataset.repartition(self.getPageCol()) + elif self.getPathCol() in dataset.columns: + dataset = dataset.repartition(self.getPathCol()) + dataset = dataset.coalesce(self.getNumPartitions()) + result = dataset.withColumn( + out_col, + pandas_udf(self.transform_udf_pandas, EmbeddingsOutput.get_schema())( + in_col, + lit(params), + ), + ) + + if not self.getKeepInputData(): + result = result.drop(in_col) + return result diff --git a/scaledp/models/embeddings/TextEmbeddings.py b/scaledp/models/embeddings/TextEmbeddings.py new file mode 100644 index 0000000..e5f8520 --- /dev/null +++ b/scaledp/models/embeddings/TextEmbeddings.py @@ -0,0 +1,75 @@ +import json +from types import MappingProxyType +from typing import Any + +import pandas as pd +from pyspark import keyword_only +from sentence_transformers import SentenceTransformer + +from scaledp.enums import Device +from scaledp.models.embeddings.BaseEmbeddings import BaseEmbeddings +from scaledp.schemas.EmbeddingsOutput import EmbeddingsOutput + + +class TextEmbeddings(BaseEmbeddings): + defaultParams = MappingProxyType( + { + "inputCol": "text", + "outputCol": "embeddings", + "keepInputData": True, + "model": "all-MiniLM-L6-v2", + "numPartitions": 1, + "partitionMap": False, + "device": Device.CPU, + "batchSize": 1, + "pageCol": "page", + "pathCol": "path", + }, + ) + + @keyword_only + def __init__(self, **kwargs: Any) -> None: + super(TextEmbeddings, self).__init__() + 
self._setDefault(**self.defaultParams) + self._set(**kwargs) + self._model = None + + def get_model(self): + if self._model is None: + self._model = SentenceTransformer(self.getModel()) + return self._model + + def transform_udf(self, text: str): + model = self.get_model() + embedding = model.encode( + text, + batch_size=self.getBatchSize(), + device=self.getSTDevice(), + ) + return EmbeddingsOutput( + path="memory", + data=embedding.tolist(), + type="text", + exception="", + ) + + @staticmethod + def transform_udf_pandas(texts: pd.Series, params: pd.Series) -> pd.DataFrame: + params = json.loads(params[0]) + model = SentenceTransformer(params["model"]) + embeddings = model.encode( + texts.tolist(), + batch_size=params["batchSize"], + device="cpu" if params["device"] == Device.CPU.value else "cuda", + ) + results = [] + for embedding in embeddings: + results.append( + EmbeddingsOutput( + path="memory", + data=embedding.tolist(), + type="text", + exception="", + ), + ) + return pd.DataFrame(results) diff --git a/scaledp/models/embeddings/__init__.py b/scaledp/models/embeddings/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/scaledp/params.py b/scaledp/params.py index 7019c4d..5a6c4ce 100644 --- a/scaledp/params.py +++ b/scaledp/params.py @@ -306,6 +306,15 @@ def getDevice(self): """ return self.getOrDefault(self.device) + def getSTDevice(self): + """ + Gets the value of device for SentenceTransformer. 
+ """ + device = self.getOrDefault(self.device) + if device == -1: + return "cpu" + return f"cuda:{device}" + class HasBatchSize(Params): batchSize = Param( diff --git a/scaledp/schemas/EmbeddingsOutput.py b/scaledp/schemas/EmbeddingsOutput.py new file mode 100644 index 0000000..3e6abab --- /dev/null +++ b/scaledp/schemas/EmbeddingsOutput.py @@ -0,0 +1,21 @@ +# Description: Schema Embeddings Output +from dataclasses import dataclass +from typing import Optional + +from scaledp.utils.dataclass import map_dataclass_to_struct, register_type + + +@dataclass(order=True) +class EmbeddingsOutput: + path: Optional[str] + data: Optional[list[float]] + type: Optional[str] + exception: Optional[str] = "" + processing_time: Optional[float] = 0.0 + + @staticmethod + def get_schema(): + return map_dataclass_to_struct(EmbeddingsOutput) + + +register_type(EmbeddingsOutput, EmbeddingsOutput.get_schema) diff --git a/scaledp/utils/dataclass.py b/scaledp/utils/dataclass.py index 286ce24..495d671 100644 --- a/scaledp/utils/dataclass.py +++ b/scaledp/utils/dataclass.py @@ -5,6 +5,7 @@ from dataclasses import fields, is_dataclass from typing import Dict, Type, Union, get_type_hints +from pyspark.ml.linalg import Vector, VectorUDT from pyspark.sql.types import ( ArrayType, BinaryType, @@ -37,6 +38,7 @@ datetime.date: DateType, bool: BooleanType, BinaryT: BinaryType, + Vector: VectorUDT, } @@ -141,6 +143,11 @@ def get_spark_type( elem_type = py_type.__args__[0] return get_spark_type(elem_type) + # Handle list types recursively + if hasattr(py_type, "__origin__") and py_type.__origin__ is list: + elem_type = py_type.__args__[0] + return ArrayType(get_spark_type(elem_type)) + raise Exception(f"Type {py_type} is not supported by PySpark") diff --git a/tests/models/embeddings/__init__.py b/tests/models/embeddings/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/models/embeddings/test_text_embeddings.py b/tests/models/embeddings/test_text_embeddings.py new file mode 
100644 index 0000000..27a4941 --- /dev/null +++ b/tests/models/embeddings/test_text_embeddings.py @@ -0,0 +1,79 @@ +from pyspark.ml import PipelineModel + +from scaledp.enums import Device +from scaledp.models.embeddings.TextEmbeddings import TextEmbeddings + + +def test_text_embeddings_pipeline(text_df): + + # Initialize the TextEmbeddings stage + text_embeddings = TextEmbeddings( + model="all-MiniLM-L6-v2", + inputCol="value", + outputCol="embeddings", + device=Device.CPU.value, + batchSize=2, + ) + + # Create a pipeline with the TextEmbeddings stage + pipeline = PipelineModel(stages=[text_embeddings]) + + result_df = pipeline.transform(text_df) + + # Cache the result for performance + result = result_df.select("embeddings", "value").cache() + + # Collect the results + data = result.collect() + + # Check that exceptions are empty + assert all(row.embeddings.exception == "" for row in data) + + # Assert that there is at least one result + assert len(data) > 0 + + # Assert that the 'embeddings' field is present in the result + assert hasattr(data[0], "embeddings") + + # Verify the embeddings are not empty + for row in data: + assert row.embeddings.data is not None + assert len(row.embeddings.data) > 0 + + +def test_text_embeddings_pipeline_pandas(text_df): + + # Initialize the TextEmbeddings stage + text_embeddings = TextEmbeddings( + model="all-MiniLM-L6-v2", + inputCol="value", + outputCol="embeddings", + device=Device.CPU.value, + partitionMap=True, + batchSize=2, + ) + + # Create a pipeline with the TextEmbeddings stage + pipeline = PipelineModel(stages=[text_embeddings]) + + result_df = pipeline.transform(text_df) + + # Cache the result for performance + result = result_df.select("embeddings", "value").cache() + + # Collect the results + data = result.collect() + + # Check that exceptions are empty + assert all(row.embeddings.exception == "" for row in data) + + # Assert that there is at least one result + assert len(data) > 0 + + # Assert that the 
'embeddings' field is present in the result + assert hasattr(data[0], "embeddings") + + # Verify the embeddings are not empty + for row in data: + assert row.embeddings.data is not None + assert len(row.embeddings.data) > 0 From 7ba9ecede61f9cd7c6190680462ef2b6b071515f Mon Sep 17 00:00:00 2001 From: mykola Date: Wed, 12 Nov 2025 08:33:24 +0300 Subject: [PATCH 2/3] Updated CHANGELOG.md --- CHANGELOG.md | 7 +++++++ docs/source/release_notes.md | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index b143213..9a4d7d1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,10 @@ +## [unreleased] + +### 🚀 Features + +- Added TextEmbeddings transformer for computing text embeddings using SentenceTransformers + + ## [0.2.5] - 10.11.2025 ### 🚀 Features diff --git a/docs/source/release_notes.md b/docs/source/release_notes.md index 98d3a5c..33d612f 100644 --- a/docs/source/release_notes.md +++ b/docs/source/release_notes.md @@ -4,6 +4,41 @@ Release Notes This document outlines the release notes for the ScaledP project. It includes information about new features, bug fixes, and other changes made in each version. 
+## [unreleased] + +### 🚀 Features + +- Added [TextEmbeddings](#TextEmbeddings) transformer for computing text embeddings using SentenceTransformers + + +## [0.2.5] - 10.11.2025 + +### 🚀 Features + +- Added param 'returnEmpty' to [ImageCropBoxes](#ImageCropBoxes) to avoid exceptions when no boxes are found +- Added labels param to the [YoloOnnxDetector](#YoloOnnxDetector) +- Improved the display of labels in [ImageDrawBoxes](#ImageDrawBoxes) + +### 🧰 Maintenance +- Updated versions of dependencies (Pandas, Numpy, OpenCV) + +### 🐛 Bug Fixes + +- Fixed conversion of the color schema in [YoloOnnxDetector](#YoloOnnxDetector) +- Fixed show utils on Google Colab +- Fixed imports of the DataFrame + +### 📘 Jupyter Notebooks + +- [YoloOnnxDetector.ipynb](https://github.com/StabRise/ScaleDP-Tutorials/blob/master/object-detection/1.YoloOnnxDetector.ipynb) +- [FaceDetection.ipynb](https://github.com/StabRise/ScaleDP-Tutorials/blob/master/object-detection/2.FaceDetection.ipynb) +- [SignatureDetection.ipynb](https://github.com/StabRise/ScaleDP-Tutorials/blob/master/object-detection/3.SignatureDetection.ipynb) + +### 📝 Blog Posts + +- [Running YOLO Models on Spark Using ScaleDP](https://stabrise.com/blog/running_yolo_on_spark_with_scaledp/) + + ## 0.2.4 - 02.11.2025 ### 🚀 Features From b5c9933775514e5fa7850ed5c6748be3ed958f52 Mon Sep 17 00:00:00 2001 From: mykola Date: Wed, 12 Nov 2025 08:46:04 +0300 Subject: [PATCH 3/3] Updated poetry.lock --- poetry.lock | 61 ++++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 53 insertions(+), 8 deletions(-) diff --git a/poetry.lock b/poetry.lock index 7b93403..555769b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1693,7 +1693,7 @@ version = "1.5.2" description = "Lightweight pipelining with Python functions" optional = false python-versions = ">=3.9" -groups = ["main"] +groups = ["main", "test"] files = [ {file = "joblib-1.5.2-py3-none-any.whl", hash = "sha256:4e1f0bdbb987e6d843c70cf43714cb276623def372df3c22fe5266b2670bc241"}, {file = 
"joblib-1.5.2.tar.gz", hash = "sha256:3faa5c39054b2f03ca547da9b2f52fde67c06240c31853f306aea97f13647b55"}, @@ -5180,7 +5180,7 @@ version = "1.7.2" description = "A set of python modules for machine learning and data mining" optional = false python-versions = ">=3.10" -groups = ["main"] +groups = ["main", "test"] files = [ {file = "scikit_learn-1.7.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6b33579c10a3081d076ab403df4a4190da4f4432d443521674637677dc91e61f"}, {file = "scikit_learn-1.7.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:36749fb62b3d961b1ce4fedf08fa57a1986cd409eff2d783bca5d4b9b5fce51c"}, @@ -5317,7 +5317,7 @@ version = "5.1.2" description = "Embeddings, Retrieval, and Reranking" optional = false python-versions = ">=3.9" -groups = ["main"] +groups = ["main", "test"] files = [ {file = "sentence_transformers-5.1.2-py3-none-any.whl", hash = "sha256:724ce0ea62200f413f1a5059712aff66495bc4e815a1493f7f9bca242414c333"}, {file = "sentence_transformers-5.1.2.tar.gz", hash = "sha256:0f6c8bd916a78dc65b366feb8d22fd885efdb37432e7630020d113233af2b856"}, @@ -5636,7 +5636,7 @@ version = "3.6.0" description = "threadpoolctl" optional = false python-versions = ">=3.9" -groups = ["main"] +groups = ["main", "test"] files = [ {file = "threadpoolctl-3.6.0-py3-none-any.whl", hash = "sha256:43a0b8fd5a2928500110039e43a5eed8480b918967083ea48dc3ab9f13c4a7fb"}, {file = "threadpoolctl-3.6.0.tar.gz", hash = "sha256:8ab8b4aa3491d812b623328249fab5302a68d2d71745c8a4c719a2fcaba9f44e"}, @@ -5870,7 +5870,6 @@ description = "image and video datasets and models for torch deep learning" optional = false python-versions = ">=3.8" groups = ["main", "test"] -markers = "sys_platform == \"darwin\"" files = [ {file = "torchvision-0.19.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:54e8513099e6f586356c70f809d34f391af71ad182fe071cc328a28af2c40608"}, {file = "torchvision-0.19.1-cp310-cp310-manylinux1_x86_64.whl", hash = 
"sha256:20a1f5e02bfdad7714e55fa3fa698347c11d829fa65e11e5a84df07d93350eed"}, @@ -5893,6 +5892,7 @@ files = [ {file = "torchvision-0.19.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:e328309b8670a2e889b2fe76a1c2744a099c11c984da9a822357bd9debd699a5"}, {file = "torchvision-0.19.1-cp39-cp39-win_amd64.whl", hash = "sha256:6616f12e00a22e7f3fedbd0fccb0804c05e8fe22871668f10eae65cf3f283614"}, ] +markers = {main = "(extra == \"ml\" or extra == \"ocr\") and sys_platform == \"darwin\"", test = "sys_platform == \"darwin\""} [package.dependencies] numpy = "*" @@ -5903,6 +5903,51 @@ torch = "2.4.1" gdown = ["gdown (>=4.7.3)"] scipy = ["scipy"] +[[package]] +name = "torchvision" +version = "0.21.0" +description = "image and video datasets and models for torch deep learning" +optional = false +python-versions = ">=3.9" +groups = ["main", "test"] +files = [ + {file = "torchvision-0.21.0-1-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:5568c5a1ff1b2ec33127b629403adb530fab81378d9018ca4ed6508293f76e2b"}, + {file = "torchvision-0.21.0-1-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:ff96666b94a55e802ea6796cabe788541719e6f4905fc59c380fed3517b6a64d"}, + {file = "torchvision-0.21.0-1-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:ffa2a16499508fe6798323e455f312c7c55f2a88901c9a7c0fb1efa86cf7e327"}, + {file = "torchvision-0.21.0-1-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:7e9e9afa150e40cd2a8f0701c43cb82a8d724f512896455c0918b987f94b84a4"}, + {file = "torchvision-0.21.0-1-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:669575b290ec27304569e188a960d12b907d5173f9cd65e86621d34c4e5b6c30"}, + {file = "torchvision-0.21.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:044ea420b8c6c3162a234cada8e2025b9076fa82504758cd11ec5d0f8cd9fa37"}, + {file = "torchvision-0.21.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:b0c0b264b89ab572888244f2e0bad5b7eaf5b696068fc0b93e96f7c3c198953f"}, + {file = 
"torchvision-0.21.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:54815e0a56dde95cc6ec952577f67e0dc151eadd928e8d9f6a7f821d69a4a734"}, + {file = "torchvision-0.21.0-cp310-cp310-win_amd64.whl", hash = "sha256:abbf1d7b9d52c00d2af4afa8dac1fb3e2356f662a4566bd98dfaaa3634f4eb34"}, + {file = "torchvision-0.21.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:110d115333524d60e9e474d53c7d20f096dbd8a080232f88dddb90566f90064c"}, + {file = "torchvision-0.21.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:3891cd086c5071bda6b4ee9d266bb2ac39c998c045c2ebcd1e818b8316fb5d41"}, + {file = "torchvision-0.21.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:54454923a50104c66a9ab6bd8b73a11c2fc218c964b1006d5d1fe5b442c3dcb6"}, + {file = "torchvision-0.21.0-cp311-cp311-win_amd64.whl", hash = "sha256:49bcfad8cfe2c27dee116c45d4f866d7974bcf14a5a9fbef893635deae322f2f"}, + {file = "torchvision-0.21.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:97a5814a93c793aaf0179cfc7f916024f4b63218929aee977b645633d074a49f"}, + {file = "torchvision-0.21.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:b578bcad8a4083b40d34f689b19ca9f7c63e511758d806510ea03c29ac568f7b"}, + {file = "torchvision-0.21.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:5083a5b1fec2351bf5ea9900a741d54086db75baec4b1d21e39451e00977f1b1"}, + {file = "torchvision-0.21.0-cp312-cp312-win_amd64.whl", hash = "sha256:6eb75d41e3bbfc2f7642d0abba9383cc9ae6c5a4ca8d6b00628c225e1eaa63b3"}, + {file = "torchvision-0.21.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:659b76c86757cb2ee4ca2db245e0740cfc3081fef46f0f1064d11adb4a8cee31"}, + {file = "torchvision-0.21.0-cp313-cp313-manylinux1_x86_64.whl", hash = "sha256:084ac3f5a1f50c70d630a488d19bf62f323018eae1b1c1232f2b7047d3a7b76d"}, + {file = "torchvision-0.21.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:5045a3a5f21ec3eea6962fa5f2fa2d4283f854caec25ada493fcf4aab2925467"}, + {file = "torchvision-0.21.0-cp313-cp313-win_amd64.whl", hash = 
"sha256:9147f5e096a9270684e3befdee350f3cacafd48e0c54ab195f45790a9c146d67"}, + {file = "torchvision-0.21.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5c22caeaae8b3c36d93459f1a5294e6f43306cff856ed243189a229331a404b4"}, + {file = "torchvision-0.21.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:e6572227228ec521618cea9ac3a368c45b7f96f1f8622dc9f1afe891c044051f"}, + {file = "torchvision-0.21.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:6bdce3890fa949219de129e85e4f6d544598af3c073afe5c44e14aed15bdcbb2"}, + {file = "torchvision-0.21.0-cp39-cp39-win_amd64.whl", hash = "sha256:8c44b6924b530d0702e88ff383b65c4b34a0eaf666e8b399a73245574d546947"}, +] +markers = {main = "extra == \"ocr\" and sys_platform != \"darwin\"", test = "sys_platform != \"darwin\""} + +[package.dependencies] +numpy = "*" +pillow = ">=5.3.0,<8.3.dev0 || >=8.4.dev0" +torch = "2.6.0" + +[package.extras] +gdown = ["gdown (>=4.7.3)"] +scipy = ["scipy"] + [[package]] name = "torchvision" version = "0.21.0+cpu" @@ -5910,7 +5955,6 @@ description = "image and video datasets and models for torch deep learning" optional = false python-versions = ">=3.9" groups = ["main", "test"] -markers = "sys_platform != \"darwin\"" files = [ {file = "torchvision-0.21.0+cpu-cp310-cp310-linux_x86_64.whl", hash = "sha256:4ed0a1be50676a7c589ba83b62c9dc0267a87e852b8cd9b7d6db27ab36c6d552"}, {file = "torchvision-0.21.0+cpu-cp310-cp310-win_amd64.whl", hash = "sha256:554ca0f5948ac89911299f8bfb6f23936d867387ea213ab235adc2814b510d0c"}, @@ -5923,6 +5967,7 @@ files = [ {file = "torchvision-0.21.0+cpu-cp39-cp39-linux_x86_64.whl", hash = "sha256:45736c703050019f158f34ab1d031a313fe91412aef00e3f0d242251ec32a7aa"}, {file = "torchvision-0.21.0+cpu-cp39-cp39-win_amd64.whl", hash = "sha256:9f369668a2c08b085a8797ea830d62bc009d73d3775cfb6c721567a61d5bcfb9"}, ] +markers = {main = "(extra == \"ocr\" or extra == \"ml\") and sys_platform != \"darwin\"", test = "sys_platform != \"darwin\""} [package.dependencies] numpy = "*" @@ -6407,11 
+6452,11 @@ files = [ [extras] llm = [] -ml = ["transformers"] +ml = ["sentence-transformers", "torch", "torch", "torchvision", "torchvision", "transformers"] ocr = ["easyocr", "python-doctr", "surya-ocr"] paddle = [] [metadata] lock-version = "2.1" python-versions = "^3.10" -content-hash = "44e11ba343de454ae8d136f03cfdaca2869e1587775106317485ec2f18ec0755" +content-hash = "e3513c9d3fa60d18fc08b4c95ad59c7f682c02f30bfa1bf72df23b2214cb7536"