
Versioning issue between jax and diffusion packages #15

@MartaTintore

Description

Hi,

I have been trying to run the Colab notebook on a Python 3 T4 GPU runtime and I am running into incompatibility issues between the package versions. If I run the Colab as it is:

# Installing PyTorch with CUDA support (matching Colab's CUDA version, usually the latest supported by PyTorch)
!pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113

# Installing Diffusers and Transformers
!pip install diffusers==0.8.0 transformers

# Installing essential and commonly used libraries
!pip install numpy pandas scipy scikit-learn matplotlib opencv-python-headless

# Additional utilities that might be useful
!pip install tqdm requests pillow

!pip install "jax[cuda12_local]==0.5.3" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

!pip install wandb

!pip install accelerate
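(As a side note, the first cell pulls PyTorch wheels built for CUDA 11.3 (cu113) while jax is pinned to a CUDA 12 build, so the two frameworks may be targeting different CUDA versions. A quick diagnostic cell along these lines, just a sketch, shows which builds are actually active:)

# Diagnostic only: print the version and CUDA build each framework reports.
import torch
print("torch", torch.__version__, "built for CUDA", torch.version.cuda)

import jax
print("jax", jax.__version__, "devices:", jax.devices())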

then, I see the following during installation:

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
chex 0.1.89 requires jax>=0.4.27, but you have jax 0.4.23 which is incompatible.
chex 0.1.89 requires jaxlib>=0.4.27, but you have jaxlib 0.4.23+cuda12.cudnn89 which is incompatible.
optax 0.2.4 requires jax>=0.4.27, but you have jax 0.4.23 which is incompatible.
optax 0.2.4 requires jaxlib>=0.4.27, but you have jaxlib 0.4.23+cuda12.cudnn89 which is incompatible.
flax 0.10.4 requires jax>=0.4.27, but you have jax 0.4.23 which is incompatible.
orbax-checkpoint 0.11.10 requires jax>=0.5.0, but you have jax 0.4.23 which is incompatible.
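All of these conflicts point the same way: chex, optax, flax, and orbax-checkpoint get resolved to versions that require jax>=0.4.27 (or >=0.5.0), while jax 0.4.23 is what actually ends up installed. A small cell like this (a sketch, diagnostic only) prints the resolved versions in one place:

# Diagnostic only: list the versions pip actually resolved.
from importlib.metadata import version, PackageNotFoundError

for pkg in ["jax", "jaxlib", "flax", "chex", "optax", "orbax-checkpoint", "diffusers"]:
    try:
        print(pkg, version(pkg))
    except PackageNotFoundError:
        print(pkg, "not installed")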

and when I try to import some of the diffusers modules:

from diffusers import StableDiffusionPipeline, DDIMScheduler

I see the following error:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-5-5f3e3eccac57> in <cell line: 0>()
----> 1 from diffusers import StableDiffusionPipeline, DDIMScheduler

12 frames
/usr/local/lib/python3.11/dist-packages/diffusers/__init__.py in <module>
     19 if is_torch_available():
     20     from .modeling_utils import ModelMixin
---> 21     from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
     22     from .optimization import (
     23         get_constant_schedule,

/usr/local/lib/python3.11/dist-packages/diffusers/models/__init__.py in <module>
     24 
     25 if is_flax_available():
---> 26     from .unet_2d_condition_flax import FlaxUNet2DConditionModel
     27     from .vae_flax import FlaxAutoencoderKL

/usr/local/lib/python3.11/dist-packages/diffusers/models/unet_2d_condition_flax.py in <module>
     14 from typing import Tuple, Union
     15 
---> 16 import flax
     17 import flax.linen as nn
     18 import jax

/usr/local/lib/python3.11/dist-packages/flax/__init__.py in <module>
     22 del configurations
     23 
---> 24 from flax import core
     25 from flax import jax_utils
     26 from flax import linen

/usr/local/lib/python3.11/dist-packages/flax/core/__init__.py in <module>
     22     unfreeze as unfreeze,
     23 )
---> 24 from .lift import (
     25     custom_vjp as custom_vjp,
     26     jit as jit,

/usr/local/lib/python3.11/dist-packages/flax/core/lift.py in <module>
     25 
     26 from flax import traceback_util
---> 27 from flax import traverse_util
     28 from flax.typing import (
     29     In,

/usr/local/lib/python3.11/dist-packages/flax/traverse_util.py in <module>
     64 
     65 import flax
---> 66 from flax.core.scope import VariableDict
     67 from flax.typing import PathParts
     68 

/usr/local/lib/python3.11/dist-packages/flax/core/scope.py in <module>
     53 )
     54 
---> 55 from . import meta, partial_eval, tracers
     56 from .frozen_dict import FrozenDict, freeze, unfreeze
     57 

/usr/local/lib/python3.11/dist-packages/flax/core/meta.py in <module>
    186 
    187 
--> 188 class Partitioned(struct.PyTreeNode, AxisMetadata[A]):
    189   """Wrapper for partitioning metadata.
    190 

/usr/lib/python3.11/abc.py in __new__(mcls, name, bases, namespace, **kwargs)

/usr/local/lib/python3.11/dist-packages/flax/struct.py in __init_subclass__(cls, **kwargs)
    233 
    234   def __init_subclass__(cls, **kwargs):
--> 235     dataclass(cls, **kwargs)  # pytype: disable=wrong-arg-types
    236 
    237   def __init__(self, *args, **kwargs):

/usr/local/lib/python3.11/dist-packages/flax/struct.py in dataclass(clz, **kwargs)
    148   data_clz.replace = replace
    149 
--> 150   jax.tree_util.register_dataclass(data_clz, data_fields, meta_fields)
    151 
    152   def to_state_dict(x):

/usr/local/lib/python3.11/dist-packages/jax/_src/deprecations.py in getattr(name)
     51       warnings.warn(message, DeprecationWarning, stacklevel=2)
     52       return fn
---> 53     raise AttributeError(f"module {module!r} has no attribute {name!r}")
     54 
     55   return getattr

AttributeError: module 'jax.tree_util' has no attribute 'register_dataclass'
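This looks like the same version mismatch surfacing at import time: flax 0.10.x calls jax.tree_util.register_dataclass when it is imported, and that attribute does not exist on jax 0.4.23 (hence flax's jax>=0.4.27 requirement above). It can be reproduced without diffusers at all, as a quick sketch:

# Diagnostic only: the attribute flax 0.10.x needs is missing on jax 0.4.23.
import jax
print(jax.__version__)
print(hasattr(jax.tree_util, "register_dataclass"))  # False on 0.4.23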

I have tried upgrading to jax 0.5.3 and jaxlib 0.5.3, but then I see this error instead:

---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
<ipython-input-2-5f3e3eccac57> in <cell line: 0>()
----> 1 from diffusers import StableDiffusionPipeline, DDIMScheduler

2 frames
/usr/local/lib/python3.11/dist-packages/diffusers/__init__.py in <module>
     29         get_scheduler,
     30     )
---> 31     from .pipeline_utils import DiffusionPipeline
     32     from .pipelines import (
     33         DanceDiffusionPipeline,

/usr/local/lib/python3.11/dist-packages/diffusers/pipeline_utils.py in <module>
     33 
     34 from .configuration_utils import ConfigMixin
---> 35 from .dynamic_modules_utils import get_class_from_dynamic_module
     36 from .hub_utils import http_user_agent
     37 from .modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT

/usr/local/lib/python3.11/dist-packages/diffusers/dynamic_modules_utils.py in <module>
     24 from typing import Dict, Optional, Union
     25 
---> 26 from huggingface_hub import HfFolder, cached_download, hf_hub_download, model_info
     27 
     28 from .utils import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging

ImportError: cannot import name 'cached_download' from 'huggingface_hub' (/usr/local/lib/python3.11/dist-packages/huggingface_hub/__init__.py)
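This second failure is a diffusers/huggingface_hub mismatch rather than a jax one: diffusers 0.8.0 still imports cached_download, which recent huggingface_hub releases no longer export. If I understand it correctly, pinning an older huggingface_hub alongside diffusers 0.8.0 might work around it (the version bound below is my assumption, not something I have verified):

# Workaround sketch: keep diffusers 0.8.0 but pin a huggingface_hub release
# that still exports cached_download (the <0.26 bound is an assumption).
!pip install diffusers==0.8.0 "huggingface_hub<0.26"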

I also tried upgrading both jax and diffusers (to diffusers 0.32.2), but then in the last cells I encounter:

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-12-1772ee7a7d70> in <cell line: 0>()
----> 1 embedding = optimize_embedding(
      2         ldm,
      3         num_steps=num_optimization_steps,
      4         batch_size=batch_size,
      5         top_k = num_keypoints,

10 frames
<ipython-input-11-aead50517aea> in optimize_embedding(ldm, context, device, num_steps, from_where, upsample_res, layers, lr, noise_level, num_tokens, top_k, augment_degrees, augment_scale, augment_translate, dataset_loc, sigma, sharpening_loss_weight, equivariance_attn_loss_weight, batch_size, num_gpus, max_len, min_dist, furthest_point_num_samples, controllers, validation, num_subjects)
     76         image = mini_batch["img"]
     77 
---> 78         attn_maps = run_and_find_attn(
     79             ldm,
     80             image,

<ipython-input-10-d782888e078d> in run_and_find_attn(ldm, image, context, noise_level, device, from_where, layers, upsample_res, indices, controllers)
    113     controllers=None,
    114 ):
--> 115     _, _ = find_pred_noise(
    116         ldm,
    117         image,

<ipython-input-10-d782888e078d> in find_pred_noise(ldm, image, context, noise_level, device)
     41     # import ipdb; ipdb.set_trace()
     42 
---> 43     pred_noise = ldm.unet(noisy_image,
     44                           ldm.scheduler.timesteps[noise_level].repeat(noisy_image.shape[0]),
     45                           context.repeat(noisy_image.shape[0], 1, 1))["sample"]

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1737             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1738         else:
-> 1739             return self._call_impl(*args, **kwargs)
   1740 
   1741     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1748                 or _global_backward_pre_hooks or _global_backward_hooks
   1749                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750             return forward_call(*args, **kwargs)
   1751 
   1752         result = None

/usr/local/lib/python3.11/dist-packages/torch/nn/parallel/data_parallel.py in forward(self, *inputs, **kwargs)
    189 
    190             if len(self.device_ids) == 1:
--> 191                 return self.module(*inputs[0], **module_kwargs[0])
    192             replicas = self.replicate(self.module, self.device_ids[: len(inputs)])
    193             outputs = self.parallel_apply(replicas, inputs, module_kwargs)

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1737             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1738         else:
-> 1739             return self._call_impl(*args, **kwargs)
   1740 
   1741     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1843 
   1844         try:
-> 1845             return inner()
   1846         except Exception:
   1847             # run always called hooks if they have not already been run

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in inner()
   1780                                 )
   1781                     else:
-> 1782                         args_result = hook(self, args)
   1783                         if args_result is not None:
   1784                             if not isinstance(args_result, tuple):

<ipython-input-6-72ab3767eb92> in hook_fn(module, input)
    200         _device = input[0].device
    201         # if device not in patched_devices:
--> 202         register_attention_control(module, controllers[_device], feature_upsample_res=feature_upsample_res)
    203         # patched_devices.add(device)
    204 

<ipython-input-6-72ab3767eb92> in register_attention_control(model, controller, feature_upsample_res)
    160 
    161     # create assertion with message
--> 162     assert cross_att_count != 0, "No cross attention layers found in the model. Please check to make sure you're using diffusers==0.8.0."
    163 
    164 def load_ldm(device, type="CompVis/stable-diffusion-v1-4", feature_upsample_res=256):

AssertionError: No cross attention layers found in the model. Please check to make sure you're using diffusers==0.8.0.
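From the assertion message, register_attention_control seems to look for attention modules by the class names used in diffusers 0.8.0, and in diffusers 0.32.2 the UNet's attention classes are named differently, so the count stays at zero. A quick check of which attention classes the loaded UNet actually contains (a sketch, assuming ldm is the pipeline loaded earlier in the notebook):

# Diagnostic only: count the UNet's module classes whose name mentions attention.
from collections import Counter

counts = Counter(type(m).__name__ for _, m in ldm.unet.named_modules())
print({name: n for name, n in counts.items() if "ttention" in name})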

Can you please recommend which versions to use for each package? Thanks!
