From 9aa94eea413e21f2f8ec7aeb696ad1653a7b6d0b Mon Sep 17 00:00:00 2001 From: Googler Date: Wed, 26 Mar 2025 04:50:16 -0700 Subject: [PATCH 01/14] Migrate lit_nlp to sklearn v1.6.1 PiperOrigin-RevId: 740720447 --- lit_nlp/components/curves_test.py | 46 +++++++++++++++++++++++++++---- pyproject.toml | 2 +- requirements.txt | 2 +- 3 files changed, 43 insertions(+), 7 deletions(-) diff --git a/lit_nlp/components/curves_test.py b/lit_nlp/components/curves_test.py index cef377c4..1b7dd194 100644 --- a/lit_nlp/components/curves_test.py +++ b/lit_nlp/components/curves_test.py @@ -51,7 +51,7 @@ def input_spec(self) -> lit_types.Spec: def output_spec(self) -> lit_types.Spec: return { 'pred': lit_types.MulticlassPreds(vocab=COLORS, parent='label'), - 'aux_pred': lit_types.MulticlassPreds(vocab=COLORS, parent='label') + 'aux_pred': lit_types.MulticlassPreds(vocab=COLORS, parent='label'), } def predict_minibatch( @@ -64,10 +64,9 @@ def predict_example(ex: lit_types.JsonDict) -> tuple[float, float, float]: return TEST_DATA[x].prediction for example in inputs: - output.append({ - 'pred': predict_example(example), - 'aux_pred': [1 / 3, 1 / 3, 1 / 3] - }) + output.append( + {'pred': predict_example(example), 'aux_pred': [1 / 3, 1 / 3, 1 / 3]} + ) return output @@ -148,6 +147,43 @@ def test_model_output_is_missing_in_config(self): config={'Label': 'red'}, ) + @parameterized.named_parameters( + dict( + testcase_name='red', + label='red', + exp_roc=[(0.0, 0.0), (0.0, 0.5), (1.0, 0.5), (1.0, 1.0)], + exp_pr=[(0.5, 0.5), (2 / 3, 1.0), (1.0, 0.5), (1.0, 0.0)], + ), + dict( + testcase_name='blue', + label='blue', + exp_roc=[(0.0, 0.0), (0.0, 1.0), (1.0, 1.0)], + exp_pr=[ + (0.3333333333333333, 1.0), + (0.5, 1.0), + (1.0, 1.0), + (1.0, 0.0), + ], + ), + ) + def test_interpreter_honors_user_selected_label( + self, label: str, exp_roc: _Curve, exp_pr: _Curve + ): + """Tests a happy scenario when a user doesn't specify the class label.""" + curves_data = self.ci.run( + inputs=self.dataset.examples, + model=self.model, + dataset=self.dataset, + config={ + curves.TARGET_LABEL_KEY: label, + curves.TARGET_PREDICTION_KEY: 'pred', + }, + ) + self.assertIn(curves.ROC_DATA, curves_data) + self.assertIn(curves.PR_DATA, curves_data) + self.assertEqual(curves_data[curves.ROC_DATA], exp_roc) + self.assertEqual(curves_data[curves.PR_DATA], exp_pr) + def test_config_spec(self): """Tests that the interpreter config has correct fields of correct type.""" spec = self.ci.config_spec() diff --git a/pyproject.toml b/pyproject.toml index 2a27db38..66536476 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ dependencies = [ "rouge-score>=0.1.2", "sacrebleu>=2.3.1", "saliency>=0.1.3", - "scikit-learn>=1.0.2", + "scikit-learn>=1.6.1", "scipy>=1.10.1", "shap>=0.42.0,<0.46.0", "six>=1.16.0", diff --git a/requirements.txt b/requirements.txt index 9c4707b0..caabea60 100644 --- a/requirements.txt +++ b/requirements.txt @@ -31,7 +31,7 @@ requests>=2.31.0 rouge-score>=0.1.2 sacrebleu>=2.3.1 saliency>=0.1.3 -scikit-learn>=1.0.2 +scikit-learn>=1.6.1 scipy>=1.10.1 shap>=0.42.0,<0.46.0 six>=1.16.0 From 6d24e2a36217b1dbccea3da07d8f2d6b5026a0fc Mon Sep 17 00:00:00 2001 From: Stephen Hicks Date: Wed, 2 Apr 2025 14:15:23 -0700 Subject: [PATCH 02/14] Automated Code Change PiperOrigin-RevId: 743275871 --- lit_nlp/client/services/group_service.ts | 1 + 1 file changed, 1 insertion(+) diff --git a/lit_nlp/client/services/group_service.ts b/lit_nlp/client/services/group_service.ts index c85a2650..f7c562ff 100644 --- 
a/lit_nlp/client/services/group_service.ts +++ b/lit_nlp/client/services/group_service.ts @@ -499,6 +499,7 @@ export class GroupService extends LitService { getFeatureValForInput( bins: NumericFeatureBins, d: IndexedInput, feature: string): string | null { const isNumerical = this.numericalFeatureNames.includes(feature); + // @ts-ignore return isNumerical ? this.getNumericalBinForExample(bins, d, feature) : this.dataService.getVal(d.id, feature); } From 3339dc63ff2d0a8450dc53ba76692dc409767b8e Mon Sep 17 00:00:00 2001 From: Rasmi Elasmar Date: Fri, 18 Apr 2025 12:24:35 -0700 Subject: [PATCH 03/14] Disable pytype ImportError for keras. PiperOrigin-RevId: 749111685 --- lit_nlp/examples/prompt_debugging/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lit_nlp/examples/prompt_debugging/models.py b/lit_nlp/examples/prompt_debugging/models.py index ca9c797e..178b5409 100644 --- a/lit_nlp/examples/prompt_debugging/models.py +++ b/lit_nlp/examples/prompt_debugging/models.py @@ -30,7 +30,7 @@ def _initialize_modeling_environment( # NOTE: Imported here and not at the top of the file to avoid # initialization issues with the environment variables above. - import keras # pylint: disable=g-import-not-at-top + import keras # pylint: disable=g-import-not-at-top # pytype: disable=import-error keras.config.set_floatx(precision) elif dl_runtime == "torch": From a32d30c4370d29cde04b4608f1842ad8305563d6 Mon Sep 17 00:00:00 2001 From: Hana Joo Date: Tue, 29 Apr 2025 13:02:30 -0700 Subject: [PATCH 04/14] Automated Code Change PiperOrigin-RevId: 752850543 --- lit_nlp/lib/caching.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lit_nlp/lib/caching.py b/lit_nlp/lib/caching.py index 169dfd6d..deb5a683 100644 --- a/lit_nlp/lib/caching.py +++ b/lit_nlp/lib/caching.py @@ -307,7 +307,7 @@ def predict(self, len(miss_idxs), len(cached_results)) else: # If all results were already cached, return them. - return cached_results + return cached_results # pytype: disable=bad-return-type with self._cache.get_pred_lock(input_keys): model_preds = list(self.wrapped.predict(progress_indicator(misses))) @@ -326,7 +326,7 @@ def predict(self, # Remove the prediction lock from the cache as the request is complete self._cache.delete_pred_lock(input_keys) - return cached_results + return cached_results # pytype: disable=bad-return-type def _get_results_from_cache(self, input_keys: list[CacheKey]): with self._cache.lock: From 151f83ea8a7399cf2f1dba3e7cce4fbe402e949a Mon Sep 17 00:00:00 2001 From: Fan Ye Date: Tue, 13 May 2025 10:07:34 -0700 Subject: [PATCH 05/14] Update lit_nlp ci workflow for github action. PiperOrigin-RevId: 758268118 --- .github/workflows/ci.yml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0432a655..b0ba5a98 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -45,8 +45,14 @@ jobs: uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} + - name: Upgrade pip + run: python -m pip install --upgrade pip==23.2 - name: Install LIT package with testing dependencies run: python -m pip install -e '.[test]' + - name: Debug dependency tree + run: | + python -m pip install pipdeptree + pipdeptree | grep decorator -A 5 || true - name: Test LIT run: pytest -v - name: Setup Node ${{ matrix.node-version }} From 7c2e75415c2f8f0587a3635231ef19b1b94bd861 Mon Sep 17 00:00:00 2001 From: Fan Ye Date: Tue, 13 May 2025 19:20:49 -0700 Subject: [PATCH 06/14] LIT Dalle-Mini demo. 
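
This change adds a self-contained text-to-image demo built on min(DALL·E); see lit_nlp/examples/dalle_mini/README.md in this patch for full setup steps. Below is a minimal sketch (not part of the patch) of driving the new model wrapper directly, assuming the min-dalle dependencies from lit_nlp/examples/dalle_mini/requirements.txt are installed and that pretrained weights can be downloaded locally; pass device="cpu" on machines without CUDA. The prompt string is illustrative only.

```python
# Sketch only: exercise DalleMiniModel outside the LIT server.
from lit_nlp.examples.dalle_mini import model as dalle_model

model = dalle_model.DalleMiniModel(device="cpu", grid_size=2)
outputs = model.predict([{"prompt": "a watercolor painting of a fox"}])
for out in outputs:
    # out["image"] is a list of base64-encoded image strings for this prompt.
    print(out["prompt"], len(out["image"]))
```
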
PiperOrigin-RevId: 758468402 --- lit_nlp/examples/dalle_mini/README.md | 27 +++++ lit_nlp/examples/dalle_mini/data.py | 18 +++ lit_nlp/examples/dalle_mini/demo.py | 98 ++++++++++++++++ lit_nlp/examples/dalle_mini/model.py | 111 +++++++++++++++++++ lit_nlp/examples/dalle_mini/requirements.txt | 19 ++++ 5 files changed, 273 insertions(+) create mode 100644 lit_nlp/examples/dalle_mini/README.md create mode 100644 lit_nlp/examples/dalle_mini/data.py create mode 100644 lit_nlp/examples/dalle_mini/demo.py create mode 100644 lit_nlp/examples/dalle_mini/model.py create mode 100644 lit_nlp/examples/dalle_mini/requirements.txt diff --git a/lit_nlp/examples/dalle_mini/README.md b/lit_nlp/examples/dalle_mini/README.md new file mode 100644 index 00000000..957cb1cf --- /dev/null +++ b/lit_nlp/examples/dalle_mini/README.md @@ -0,0 +1,27 @@ +Dalle_Mini Demo for the Learning Interpretability Tool +======================================================= + +This demo showcases how LIT can be used in text-to-image generation mode. It is +based on the mini-dalle Mini model +(https://www.piwheels.org/project/dalle-mini/). + +You will need a standalone virtual environment for the Python libraries, which +you can set up using the following commands from the root of the LIT repo. + +```sh +# Create the virtual environment. You may want to use python3 or python3.10 +# depends on how many Python versions you have installed and their aliases. +python -m venv .dalle-mini +source .dalle-mini/bin/activate +# This requirements.txt file will also install the core LIT library deps. +pip install -r ./lit_nlp/examples/dalle_mini/requirements.txt +# The LIT web app still needs to be built in the usual way. +(cd ./lit_nlp && yarn && yarn build) +``` + +Once your virtual environment is setup, you can launch the demo with the +following command. + +```sh +python -m lit_nlp.examples.dalle_mini.demo +``` \ No newline at end of file diff --git a/lit_nlp/examples/dalle_mini/data.py b/lit_nlp/examples/dalle_mini/data.py new file mode 100644 index 00000000..e54b1ca3 --- /dev/null +++ b/lit_nlp/examples/dalle_mini/data.py @@ -0,0 +1,18 @@ +"""Data loaders for dalle-mini model.""" + +from lit_nlp.api import dataset as lit_dataset +from lit_nlp.api import types as lit_types + + +class DallePrompts(lit_dataset.Dataset): + + def __init__(self, prompts: list[str]): + self.examples = [] + for prompt in prompts: + self.examples.append({"prompt": prompt}) + + def spec(self) -> lit_types.Spec: + return {"prompt": lit_types.TextSegment()} + + def __iter__(self): + return iter(self.examples) diff --git a/lit_nlp/examples/dalle_mini/demo.py b/lit_nlp/examples/dalle_mini/demo.py new file mode 100644 index 00000000..18cbc885 --- /dev/null +++ b/lit_nlp/examples/dalle_mini/demo.py @@ -0,0 +1,98 @@ +r"""Example for dalle-mini demo model. + +To run locally with a small number of examples: + python -m lit_nlp.examples.dalle_mini.demo + + +Then navigate to localhost:5432 to access the demo UI. 
+""" + +from collections.abc import Sequence +import sys +from typing import Optional + +from absl import app +from absl import flags +from lit_nlp import app as lit_app +from lit_nlp import dev_server +from lit_nlp import server_flags +from lit_nlp.api import layout +from lit_nlp.examples.dalle_mini import data as dalle_data +from lit_nlp.examples.dalle_mini import model as dalle_model + + +# NOTE: additional flags defined in server_flags.py +_FLAGS = flags.FLAGS +_FLAGS.set_default("development_demo", True) +_FLAGS.set_default("default_layout", "DALLE_LAYOUT") + +_FLAGS.DEFINE_integer("grid_size", 4, "The grid size to use for the model.") + +_MODELS = (["dalle-mini"],) + +_CANNED_PROMPTS = ["I have a dream", "I have a shiba dog named cola"] + +# Custom frontend layout; see api/layout.py +_modules = layout.LitModuleName +_DALLE_LAYOUT = layout.LitCanonicalLayout( + upper={ + "Main": [ + _modules.DataTableModule, + _modules.DatapointEditorModule, + ] + }, + lower={ + "Predictions": [ + _modules.GeneratedImageModule, + _modules.GeneratedTextModule, + ], + }, + description="Custom layout for Text to Image models.", +) + + +CUSTOM_LAYOUTS = layout.DEFAULT_LAYOUTS | {"DALLE_LAYOUT": _DALLE_LAYOUT} + + +def get_wsgi_app() -> Optional[dev_server.LitServerType]: + _FLAGS.set_default("server_type", "external") + _FLAGS.set_default("demo_mode", True) + # Parse flags without calling app.run(main), to avoid conflict with + # gunicorn command line flags. + unused = _FLAGS(sys.argv, known_only=True) + return main(unused) + + +def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]: + if len(argv) > 1: + raise app.UsageError("Too many command-line arguments.") + + # Load models, according to the --models flag. + models = {} + + model_loaders: lit_app.ModelLoadersMap = {} + model_loaders["dalle-mini"] = ( + dalle_model.DalleMiniModel, + dalle_model.DalleMiniModel.init_spec(), + ) + + datasets = {"examples": dalle_data.DallePrompts(_CANNED_PROMPTS)} + dataset_loaders: lit_app.DatasetLoadersMap = {} + dataset_loaders["text_to_image"] = ( + dalle_data.DallePrompts, + dalle_data.DallePrompts.init_spec(), + ) + + lit_demo = dev_server.Server( + models=models, + model_loaders=model_loaders, + datasets=datasets, + dataset_loaders=dataset_loaders, + layouts=CUSTOM_LAYOUTS, + **server_flags.get_flags(), + ) + return lit_demo.serve() + + +if __name__ == "__main__": + app.run(main) diff --git a/lit_nlp/examples/dalle_mini/model.py b/lit_nlp/examples/dalle_mini/model.py new file mode 100644 index 00000000..487072d6 --- /dev/null +++ b/lit_nlp/examples/dalle_mini/model.py @@ -0,0 +1,111 @@ +"""LIT wrappers for MiniDalleModel.""" + +from collections.abc import Iterable + +from lit_nlp.api import model as lit_model +from lit_nlp.api import types as lit_types +from lit_nlp.lib import image_utils +from min_dalle import MinDalle +import numpy as np +from PIL import Image +import torch + + +class DalleMiniModel(lit_model.Model): + """LIT model wrapper for Dalle-Mini Text-to-Image model. + + This wrapper simplifies the pipeline using Dalle-Mini for text-to-image + generation. + + + The basic flow within this model wrapper's predict() function is: + + + 1. Dalle-Mini processes the text prompt. + 2. Images are directly generated by Dalle-Mini. 
+ """ + + def __init__( + self, + device: str = "cuda", # Use "cuda" for GPU or "cpu" for CPU + grid_size: int = 4, # each batch will generate grid_size**2 images + temperature: float = 0.5, + top_k: int = 256, + supercondition_factor: int = 32, + ): + super().__init__() + self.grid_size = grid_size + self.temperature = temperature + self.top_k = top_k + self.supercondition_factor = supercondition_factor + + # Load Dalle-Mini model + self.model = MinDalle( + models_root="./pretrained", + dtype=torch.float32, + device=device, + is_mega=True, + is_reusable=True, + ) + + def max_minibatch_size(self) -> int: + return 8 + + def predict( + self, inputs: Iterable[lit_types.JsonDict], **unused_kw + ) -> Iterable[lit_types.JsonDict]: + """Generate images based on the input prompts.""" + + def tensor_to_pil_image(tensor): + img_np = tensor.detach().cpu().numpy() + img_np = np.squeeze(img_np) + if img_np.ndim == 2: + img_np = np.stack([img_np] * 3, axis=-1) + elif img_np.ndim != 3 or img_np.shape[2] != 3: + raise ValueError( + f"Unexpected image shape: {img_np.shape}. Expected (H, W, 3)." + ) + + img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min()) * 255 + img_np = img_np.clip(0, 255).astype(np.uint8) + return Image.fromarray(img_np) + + prompts = [ex["prompt"] for ex in inputs] + images = [] + for prompt in prompts: + # Generate images using the model + generated_images = self.model.generate_images( + text=prompt, + seed=-1, + grid_size=self.grid_size, + is_seamless=False, + temperature=self.temperature, + top_k=self.top_k, + supercondition_factor=self.supercondition_factor, + is_verbose=False, + ) + pil_images = [] + for img_tensor in generated_images: + pil_images.append(tensor_to_pil_image(img_tensor)) + images.append({ + "image": [ + image_utils.convert_pil_to_image_str(img) for img in pil_images + ], + "prompt": prompt, + }) + + return images + + def input_spec(self): + return { + "grid_size": lit_types.Scalar(), + "temperature": lit_types.Scalar(), + "top_k": lit_types.Scalar(), + "supercondition_factor": lit_types.Scalar(), + } + + def output_spec(self): + return { + "image": lit_types.ImageBytesList(), + "prompt": lit_types.TextSegment(), + } diff --git a/lit_nlp/examples/dalle_mini/requirements.txt b/lit_nlp/examples/dalle_mini/requirements.txt new file mode 100644 index 00000000..b5199a94 --- /dev/null +++ b/lit_nlp/examples/dalle_mini/requirements.txt @@ -0,0 +1,19 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +-r ../../../requirements.txt + +# Dalle-Mini dependencies +min_dalle==0.4.11 From f2bde3db2dc567570bbd7691b21ff2100991d8b0 Mon Sep 17 00:00:00 2001 From: Fan Ye Date: Wed, 14 May 2025 10:49:39 -0700 Subject: [PATCH 07/14] Add a GCP text to image demo for LIT. 
PiperOrigin-RevId: 758748858 --- .../examples/gcp_text_to_image/datasets.py | 23 +++ lit_nlp/examples/gcp_text_to_image/demo.py | 129 ++++++++++++++++ lit_nlp/examples/gcp_text_to_image/models.py | 141 ++++++++++++++++++ .../examples/gcp_text_to_image/models_test.py | 83 +++++++++++ 4 files changed, 376 insertions(+) create mode 100644 lit_nlp/examples/gcp_text_to_image/datasets.py create mode 100644 lit_nlp/examples/gcp_text_to_image/demo.py create mode 100644 lit_nlp/examples/gcp_text_to_image/models.py create mode 100644 lit_nlp/examples/gcp_text_to_image/models_test.py diff --git a/lit_nlp/examples/gcp_text_to_image/datasets.py b/lit_nlp/examples/gcp_text_to_image/datasets.py new file mode 100644 index 00000000..431ae53c --- /dev/null +++ b/lit_nlp/examples/gcp_text_to_image/datasets.py @@ -0,0 +1,23 @@ +"""Data loaders for text to image models.""" + +from lit_nlp.api import dataset as lit_dataset +from lit_nlp.api import types as lit_types + + +class TextToImageDataset(lit_dataset.Dataset): + """TextToImageDataset is a dataset that contains a list of prompts. + + It is used to generate images using the text to image models. + """ + + def __init__(self, prompts: list[str]): + self._examples = [] + for prompt in prompts: + self._examples.append({"prompt": prompt}) + + @classmethod + def init_spec(cls) -> lit_types.Spec: + return {"prompt": lit_types.TextSegment(required=True)} + + def spec(self) -> lit_types.Spec: + return {"prompt": lit_types.TextSegment()} diff --git a/lit_nlp/examples/gcp_text_to_image/demo.py b/lit_nlp/examples/gcp_text_to_image/demo.py new file mode 100644 index 00000000..a7e4cce9 --- /dev/null +++ b/lit_nlp/examples/gcp_text_to_image/demo.py @@ -0,0 +1,129 @@ +r"""A blank demo ready to load generative text to image models and datasets. + +To use with VertexAI Model Garden models, you must install the following packages: + pip install vertexai>=1.49.0 +To run the demo, you must set you GCP project location and project id. + +Currently, the demo only supports the image generation models in the Model +Garden. + +The following command can be used to run the demo: + blaze run -c opt examples/gcp_text_to_image:demo -- \ + --project_id=$GCP_PROJECT_ID \ + --project_location=$GCP_PROJECT_LOCATION \ + --alsologtostderr +Then navigate to localhost:5432 to access the demo UI. +""" + +from collections.abc import Sequence +import sys +from typing import Optional + +from absl import app +from absl import flags +from absl import logging +import google.auth +from google.cloud.aiplatform import vertexai +from lit_nlp import app as lit_app +from lit_nlp import dev_server +from lit_nlp import server_flags +from lit_nlp.api import layout +from lit_nlp.examples.gcp_text_to_image import datasets as gcp_text_to_image_datasets +from lit_nlp.examples.gcp_text_to_image import models as gcp_text_to_image_models + + +FLAGS = flags.FLAGS +# Define GCP project information and vertex AI API key. 
+LOCATION = flags.DEFINE_string( + 'project_location', + None, + 'Please enter your GCP project location', + required=True, +) +PROJECT_ID = flags.DEFINE_string( + 'project_id', + None, + 'Please enter your project id', + required=True, +) + +# Custom frontend layout; see api/layout.py +_modules = layout.LitModuleName +_IMAGE_LAYOUT = layout.LitCanonicalLayout( + upper={ + 'Main': [ + _modules.DataTableModule, + _modules.DatapointEditorModule, + ] + }, + lower={ + 'Predictions': [ + _modules.GeneratedImageModule, + _modules.GeneratedTextModule, + ], + }, + description='Custom layout for Text to Image models.', +) + + +CUSTOM_LAYOUTS = layout.DEFAULT_LAYOUTS | {'_IMAGE_LAYOUT': _IMAGE_LAYOUT} + +_CANNED_PROMPTS = ['I have a dream', 'I have a shiba dog named cola'] + + +def get_wsgi_app() -> Optional[dev_server.LitServerType]: + """Return WSGI app for container-hosted demos.""" + FLAGS.set_default('server_type', 'external') + FLAGS.set_default('demo_mode', True) + # Parse flags without calling app.run(main), to avoid conflict with + # gunicorn command line flags. + unused = flags.FLAGS(sys.argv, known_only=True) + if unused: + logging.info( + 'generateive_demo:get_wsgi_app() called with unused args: %s', unused + ) + return main([]) + + +def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]: + if len(argv) > 1: + raise app.UsageError('Too many command-line arguments.') + + creds, _ = google.auth.default( + scopes=['https://www.googleapis.com/auth/cloud-platform'] + ) + creds = creds.with_quota_project(PROJECT_ID.value) + vertexai.init( + project=PROJECT_ID.value, + location=LOCATION.value, + credentials=creds, + ) + models = {} + model_loaders: lit_app.ModelLoadersMap = {} + model_loaders['text_to_image'] = ( + gcp_text_to_image_models.VertexModelGardenModel, + gcp_text_to_image_models.VertexModelGardenModel.init_spec(), + ) + + datasets = { + 'prompts': gcp_text_to_image_datasets.TextToImageDataset(_CANNED_PROMPTS) + } + dataset_loaders: lit_app.DatasetLoadersMap = {} + dataset_loaders['text_to_image'] = ( + gcp_text_to_image_datasets.TextToImageDataset, + gcp_text_to_image_datasets.TextToImageDataset.init_spec(), + ) + + lit_demo = dev_server.Server( + models=models, + model_loaders=model_loaders, + datasets=datasets, + dataset_loaders=dataset_loaders, + layout=layout.DEFAULT_LAYOUTS, + **server_flags.get_flags() + ) + return lit_demo.serve() + + +if __name__ == '__main__': + app.run(main) diff --git a/lit_nlp/examples/gcp_text_to_image/models.py b/lit_nlp/examples/gcp_text_to_image/models.py new file mode 100644 index 00000000..ef99946b --- /dev/null +++ b/lit_nlp/examples/gcp_text_to_image/models.py @@ -0,0 +1,141 @@ +"""Model Wrapper for generative models.""" + +from collections.abc import Iterable +import io +import logging +import time +from typing import Literal, Optional, Union +from vertexai import vision_models +from lit_nlp.api import model as lit_model +from lit_nlp.api import types as lit_types +from lit_nlp.lib import image_utils +from PIL import Image + +_MAX_NUM_RETRIES = 5 + +_DEFAULT_CANDIDATE_COUNT = 1 + +_DEFAULT_MAX_OUTPUT_TOKENS = 256 + +_IMAGE_PREFIX = 'data:image/png;base64,' + + +class VertexModelGardenModel(lit_model.BatchedRemoteModel): + """VertexModelGardenModel is a wrapper for Vertex AI Model Garden model. + + Attributes: + model_name: The name of the model to load. + max_concurrent_requests: The maximum number of concurrent requests to the + model. + max_qps: The maximum number of queries per second to the model. 
+ temperature: The temperature to use for the model. + candidate_count: The number of candidates to generate. + max_output_tokens: The maximum number of tokens to generate. + + Please note the model will predict all examples at a fixed temperature. + """ + + def __init__( + self, + model_name: str = 'imagen-3.0-generate-002', + max_concurrent_requests: int = 4, + max_qps: Union[int, float] = 25, + aspect_ratio: Optional[ + Literal['16:9', '1:1', '3:4', '4:3', '9:16'] + ] = None, + width: int = 256, + height: int = 256, + ): + super().__init__(max_concurrent_requests, max_qps) + # Connect to the remote model. + self._model = vision_models.ImageGenerationModel.from_pretrained(model_name) + self._aspect_ratio = aspect_ratio + self._width = width + self._height = height + + def query_model(self, prompt: str, **unused_kw) -> list[lit_types.JsonDict]: + num_attempts = 0 + predictions = None + exception = None + width = self._width + height = self._height + + while num_attempts < _MAX_NUM_RETRIES and predictions is None: + num_attempts += 1 + + try: + predictions = self._model.generate_images( + prompt=prompt, + aspect_ratio=self._aspect_ratio, + ) + except Exception as e: # pylint: disable=broad-except + wait_time = 2**num_attempts + exception = e + logging.warning('Waiting %ds to retry... (%s)', wait_time, e) + time.sleep(2**num_attempts) + + if predictions is None: + raise ValueError( + f'Failed to get predictions. ({exception})' + ) from exception + + if not isinstance(predictions, Iterable): + raise ValueError(f'Predictions is not an Iterable: {type(predictions)}') + + images = [] + for image_ in predictions.images: + pil_img = Image.open(io.BytesIO(getattr(image_, '_image_bytes'))) + pil_img = pil_img.resize((width, height)) + images.append(image_utils.convert_pil_to_image_str(pil_img)) + + return images + + def predict_minibatch( + self, inputs: list[lit_types.JsonDict] + ) -> list[lit_types.JsonDict]: + """The model can generate up to 8 images per run, but LIT may only show one due to frontend limitations. + + In MinDalle demos, the grid_size parameter controls layout—for example, + grid_size=2 creates a 2x2 grid of sub-images, rendered as a single final + image. That’s why only one image might appear even if multiple are + generated. + + Args: + inputs: A list of input dictionaries, each containing a 'prompt'. + + Returns: + A list of dictionaries, each containing the generated 'image' and the + original 'prompt'. 
+ """ + results = [] + for inp in inputs: + prompt = inp['prompt'] + b64_strs = self.query_model(prompt) + if not b64_strs: + raise ValueError(f'No images generated for prompt: {prompt}') + results.append({ + 'image': b64_strs[0], + 'prompt': prompt, + }) + return results + + @classmethod + def init_spec(cls) -> lit_types.Spec: + return { + 'model_name': lit_types.String( + default='imagen-3.0-generate-002', required=True + ), + 'aspect_ratio': lit_types.String(default='1:1', required=False), + 'width': lit_types.Integer(default=256, required=False), + 'height': lit_types.Integer(default=256, required=False), + } + + def input_spec(self) -> lit_types.Spec: + return { + 'prompt': lit_types.TextSegment(), + } + + def output_spec(self): + return { + 'image': lit_types.ImageBytesList(), + } diff --git a/lit_nlp/examples/gcp_text_to_image/models_test.py b/lit_nlp/examples/gcp_text_to_image/models_test.py new file mode 100644 index 00000000..e3ac4558 --- /dev/null +++ b/lit_nlp/examples/gcp_text_to_image/models_test.py @@ -0,0 +1,83 @@ +import base64 +from unittest import mock +from absl.testing import absltest +from vertexai import vision_models +from lit_nlp.examples.gcp_text_to_image import models + + +class MockModel: + + def __init__( + self, images=None, raise_exception=False, sample_image_bytes=None + ): + self.images = images if images else [] + self.raise_exception = raise_exception + self.call_count = 0 + self.sample_image_bytes = sample_image_bytes + + def generate_images(self, prompt, aspect_ratio=None): + _, _ = prompt, aspect_ratio + self.call_count += 1 + if self.raise_exception: + raise ValueError("Mock Model Error") + + if self.sample_image_bytes: + # Create a mock GeneratedImage instance, passing image_bytes + mock_image = mock.create_autospec( + vision_models.GeneratedImage, instance=True + ) + mock_image._image_bytes = self.sample_image_bytes + mock_response = vision_models.ImageGenerationResponse(images=[mock_image]) + return mock_response + + return vision_models.ImageGenerationResponse(images=[]) + + +class ModelsTest(absltest.TestCase): + + def setUp(self): + super().setUp() + # Create a sample image for testing + png_base64 = b"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMBAKh72VgAAAAASUVORK5CYII=" + self.sample_image_bytes = base64.b64decode(png_base64) + + @mock.patch( + "vertexai.vision_models.ImageGenerationModel.from_pretrained", + ) + @mock.patch("PIL.Image.open") + def test_query_model(self, mock_image_open, mock_from_pretrained): + # Create a MockModel instance + mock_model = MockModel( + sample_image_bytes=self.sample_image_bytes, + ) + # Configure mock_from_pretrained to return the mock_model + mock_from_pretrained.return_value = mock_model + + model = models.VertexModelGardenModel(model_name="test_model_name") + mock_image = mock.Mock() + + mock_image.resize.return_value = mock_image + mock_image_open.return_value = mock_image + + output = model.predict_minibatch( + inputs=[{"prompt": "I say yes you say no"}] + ) + result = list(output) + + self.assertLen(result, 1) + self.assertIn("image", result[0]) + self.assertIn("prompt", result[0]) + self.assertEqual(result[0]["prompt"], "I say yes you say no") + + # Validate that the image is a base64 string + self.assertTrue(result[0]["image"].startswith("data:image/png")) + self.assertIsInstance(result[0]["image"], str) + + mock_from_pretrained.assert_called_once_with("test_model_name") + + # Assert that mock_generate_content was called + self.assertEqual(mock_model.call_count, 1) + + +if 
__name__ == "__main__": + absltest.main() From ee87faf790fb4a40aee21f960e02b7c2bcbc8752 Mon Sep 17 00:00:00 2001 From: Fan Ye Date: Wed, 14 May 2025 11:22:31 -0700 Subject: [PATCH 08/14] Test and fix errors in dalle_mini examples. Now it should works fine locally. PiperOrigin-RevId: 758762358 --- lit_nlp/examples/dalle_mini/data.py | 15 +++++--- lit_nlp/examples/dalle_mini/demo.py | 36 ++++++++++++++++++-- lit_nlp/examples/dalle_mini/model.py | 7 +--- lit_nlp/examples/dalle_mini/requirements.txt | 2 ++ 4 files changed, 47 insertions(+), 13 deletions(-) diff --git a/lit_nlp/examples/dalle_mini/data.py b/lit_nlp/examples/dalle_mini/data.py index e54b1ca3..850d7e81 100644 --- a/lit_nlp/examples/dalle_mini/data.py +++ b/lit_nlp/examples/dalle_mini/data.py @@ -5,14 +5,19 @@ class DallePrompts(lit_dataset.Dataset): + """DallePrompts is a dataset that contains a list of prompts. + + It is used to generate images using the dalle-mini model. + """ def __init__(self, prompts: list[str]): - self.examples = [] + self._examples = [] for prompt in prompts: - self.examples.append({"prompt": prompt}) + self._examples.append({"prompt": prompt}) + + @classmethod + def init_spec(cls) -> lit_types.Spec: + return {"prompt": lit_types.TextSegment(required=True)} def spec(self) -> lit_types.Spec: return {"prompt": lit_types.TextSegment()} - - def __iter__(self): - return iter(self.examples) diff --git a/lit_nlp/examples/dalle_mini/demo.py b/lit_nlp/examples/dalle_mini/demo.py index 18cbc885..3c043aff 100644 --- a/lit_nlp/examples/dalle_mini/demo.py +++ b/lit_nlp/examples/dalle_mini/demo.py @@ -1,8 +1,42 @@ r"""Example for dalle-mini demo model. +First run following command to install required packages: + pip install -r ./lit_nlp/examples/dalle_mini/requirements.txt + To run locally with a small number of examples: python -m lit_nlp.examples.dalle_mini.demo +By default, this module uses the "cuda" device for image generation. +The `requirements.txt` file installs a CUDA-enabled version of PyTorch for GPU +acceleration. + +If you are running on a machine without a compatible GPU or CUDA drivers, +you must switch the device to "cpu" and reinstall the CPU-only version of +PyTorch. + +Usage: + - Default: device="cuda" + - On CPU-only machines: + 1. Set device="cpu" during model initialization + 2. Uninstall the CUDA version of PyTorch: + pip uninstall torch + 3. Install the CPU-only version: + pip install torch==2.1.2+cpu --extra-index-url + https://download.pytorch.org/whl/cpu + +Example: + >>> model = MinDalle(..., device="cpu") + +Check CUDA availability: + >>> import torch + >>> torch.cuda.is_available() + False # if no GPU support is present + +Error Handling: + - If CUDA is selected but unsupported, you will see: + AssertionError: Torch not compiled with CUDA enabled + - To fix this, either install the correct CUDA-enabled PyTorch or switch to + CPU mode. Then navigate to localhost:5432 to access the demo UI. 
""" @@ -26,8 +60,6 @@ _FLAGS.set_default("development_demo", True) _FLAGS.set_default("default_layout", "DALLE_LAYOUT") -_FLAGS.DEFINE_integer("grid_size", 4, "The grid size to use for the model.") - _MODELS = (["dalle-mini"],) _CANNED_PROMPTS = ["I have a dream", "I have a shiba dog named cola"] diff --git a/lit_nlp/examples/dalle_mini/model.py b/lit_nlp/examples/dalle_mini/model.py index 487072d6..78492aae 100644 --- a/lit_nlp/examples/dalle_mini/model.py +++ b/lit_nlp/examples/dalle_mini/model.py @@ -97,12 +97,7 @@ def tensor_to_pil_image(tensor): return images def input_spec(self): - return { - "grid_size": lit_types.Scalar(), - "temperature": lit_types.Scalar(), - "top_k": lit_types.Scalar(), - "supercondition_factor": lit_types.Scalar(), - } + return {"prompt": lit_types.TextSegment()} def output_spec(self): return { diff --git a/lit_nlp/examples/dalle_mini/requirements.txt b/lit_nlp/examples/dalle_mini/requirements.txt index b5199a94..184c04cf 100644 --- a/lit_nlp/examples/dalle_mini/requirements.txt +++ b/lit_nlp/examples/dalle_mini/requirements.txt @@ -17,3 +17,5 @@ # Dalle-Mini dependencies min_dalle==0.4.11 +torch==2.1.2+cu118 +--extra-index-url https://download.pytorch.org/whl/cu118 \ No newline at end of file From f34ab60bead4a90259e74dc3cfe6d6c5fb46e2ab Mon Sep 17 00:00:00 2001 From: Fan Ye Date: Fri, 6 Jun 2025 10:52:42 -0700 Subject: [PATCH 09/14] Adding text-to-image demo in demos doc. PiperOrigin-RevId: 768138158 --- website/sphinx_src/demos.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/website/sphinx_src/demos.md b/website/sphinx_src/demos.md index 448c0d23..5be0910b 100644 --- a/website/sphinx_src/demos.md +++ b/website/sphinx_src/demos.md @@ -83,6 +83,17 @@ Generative AI Toolkit. -------------------------------------------------------------------------------- +## Text To Image Demo + +### min(DALL·E) + +**Code:** [examples/dalle_mini/demo.py](https://github.com/PAIR-code/lit/blob/main/lit_nlp/examples/dalle_mini/demo.py) + +* Support text to image generation in LIT using + [min(DALL·E)](https://github.com/kuprel/min-dalle) model. 
+ +-------------------------------------------------------------------------------- + ## Multimodal ### Tabular Data: Penguin Classification From 4542238f11a9b16be00bf61f669e318ed9d453fc Mon Sep 17 00:00:00 2001 From: Googler Date: Tue, 17 Jun 2025 23:48:57 -0700 Subject: [PATCH 10/14] Automated Code Change PiperOrigin-RevId: 772786015 --- .../client/modules/annotated_text_module.ts | 22 +++++++++++++------ .../modules/feature_attribution_module.ts | 10 ++++++--- lit_nlp/client/modules/pdp_module.ts | 9 ++++++-- 3 files changed, 29 insertions(+), 12 deletions(-) diff --git a/lit_nlp/client/modules/annotated_text_module.ts b/lit_nlp/client/modules/annotated_text_module.ts index 3009b954..d01f04e2 100644 --- a/lit_nlp/client/modules/annotated_text_module.ts +++ b/lit_nlp/client/modules/annotated_text_module.ts @@ -19,12 +19,15 @@ import {customElement} from 'lit/decorators.js'; import {makeObservable, observable} from 'mobx'; import {LitModule} from '../core/lit_module'; -import {type AnnotationGroups, TextSegments} from '../elements/annotated_text_vis'; +import {type AnnotationGroups, type AnnotationSpec, type SegmentSpec, TextSegments} from '../elements/annotated_text_vis'; import {MultiSegmentAnnotations, TextSegment} from '../lib/lit_types'; import {styles as sharedStyles} from '../lib/shared_styles.css'; import {type IndexedInput, ModelInfoMap, Spec} from '../lib/types'; import {doesOutputSpecContain, filterToKeys, findSpecKeys} from '../lib/utils'; +// This should be removed. +type AnyDuringMigration = any; + /** LIT module for model output. */ @customElement('annotated-text-gold-module') export class AnnotatedTextGoldModule extends LitModule { @@ -53,13 +56,15 @@ export class AnnotatedTextGoldModule extends LitModule { // Text segment fields const segmentNames = findSpecKeys(dataSpec, TextSegment); const segments: TextSegments = filterToKeys(input.data, segmentNames); - const segmentSpec = filterToKeys(dataSpec, segmentNames); + const segmentSpec: SegmentSpec = + filterToKeys(dataSpec, segmentNames) as AnyDuringMigration; // Annotation fields const annotationNames = findSpecKeys(dataSpec, MultiSegmentAnnotations); const annotations: AnnotationGroups = filterToKeys(input.data, annotationNames); - const annotationSpec = filterToKeys(dataSpec, annotationNames); + const annotationSpec: AnnotationSpec = + filterToKeys(dataSpec, annotationNames) as AnyDuringMigration; // If more than one model is selected, AnnotatedTextModule will be offset // vertically due to the model name header, while this one won't be. 
@@ -149,12 +154,15 @@ export class AnnotatedTextModule extends LitModule { findSpecKeys(this.appState.currentDatasetSpec, TextSegment); const segments: TextSegments = filterToKeys(this.currentData.data, segmentNames); - const segmentSpec = - filterToKeys(this.appState.currentDatasetSpec, segmentNames); + const segmentSpec: SegmentSpec = + filterToKeys(this.appState.currentDatasetSpec, segmentNames) as + AnyDuringMigration; const outputSpec = this.appState.getModelSpec(this.model).output; - const annotationSpec = filterToKeys( - outputSpec, findSpecKeys(outputSpec, MultiSegmentAnnotations)); + const annotationSpec: AnnotationSpec = + filterToKeys( + outputSpec, findSpecKeys(outputSpec, MultiSegmentAnnotations)) as + AnyDuringMigration; // clang-format off return html` this.colorMap.bgCmap(val); + const scale: D3Scale = + ((val: number) => this.colorMap.bgCmap(val)) as AnyDuringMigration; scale.domain = () => this.colorMap.colorScale.domain(); // clang-format off diff --git a/lit_nlp/client/modules/pdp_module.ts b/lit_nlp/client/modules/pdp_module.ts index 107bf866..fee10da4 100644 --- a/lit_nlp/client/modules/pdp_module.ts +++ b/lit_nlp/client/modules/pdp_module.ts @@ -44,6 +44,9 @@ interface AllPdpInfo { // Data for bar or line charts. type ChartInfo = Map; +// This should be removed. +type AnyDuringMigration = any; + /** * A LIT module that renders regression results. */ @@ -172,14 +175,16 @@ export class PdpModule extends LitModule { const yRange = isClassification ? [0, 1] : []; const renderChart = (chartData: ChartInfo) => { if (isNumeric) { + const chartMap: Map = chartData as AnyDuringMigration; return html` + .scores=${chartMap} .yScale=${yRange}> `; } else { + const chartMap: Map = chartData as AnyDuringMigration; return html` + .scores=${chartMap} .yScale=${yRange}> `; } From 5b760bb2970a66b3d0e85a7ca6359f4fb9eaf2f4 Mon Sep 17 00:00:00 2001 From: Googler Date: Fri, 15 Aug 2025 05:15:16 -0700 Subject: [PATCH 11/14] Automated Code Change PiperOrigin-RevId: 795438315 --- lit_nlp/client/core/slice_module.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lit_nlp/client/core/slice_module.ts b/lit_nlp/client/core/slice_module.ts index 674e652a..43b454b5 100644 --- a/lit_nlp/client/core/slice_module.ts +++ b/lit_nlp/client/core/slice_module.ts @@ -160,7 +160,7 @@ export class SliceModule extends LitModule { // clang-format off return html`
- {onKeyUp(e);}}/>