-
Notifications
You must be signed in to change notification settings - Fork 11
Open
Description
Hi,
I have been trying to run the Colab notebook on a Python 3 T4 GPU runtime, and I am seeing incompatibility issues between the different package versions. If I run the notebook as is:
# Installing PyTorch with CUDA support (matching Colab's CUDA version, usually the latest supported by PyTorch)
!pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
# Installing Diffusers and Transformers
!pip install diffusers==0.8.0 transformers
# Installing essential and commonly used libraries
!pip install numpy pandas scipy scikit-learn matplotlib opencv-python-headless
# Additional utilities that might be useful
!pip install tqdm requests pillow
!pip install "jax[cuda12_local]==0.5.3" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
!pip install wandb
!pip install accelerate
Then I see the following during installation:
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
chex 0.1.89 requires jax>=0.4.27, but you have jax 0.4.23 which is incompatible.
chex 0.1.89 requires jaxlib>=0.4.27, but you have jaxlib 0.4.23+cuda12.cudnn89 which is incompatible.
optax 0.2.4 requires jax>=0.4.27, but you have jax 0.4.23 which is incompatible.
optax 0.2.4 requires jaxlib>=0.4.27, but you have jaxlib 0.4.23+cuda12.cudnn89 which is incompatible.
flax 0.10.4 requires jax>=0.4.27, but you have jax 0.4.23 which is incompatible.
orbax-checkpoint 0.11.10 requires jax>=0.5.0, but you have jax 0.4.23 which is incompatible.
And when I try to import some of the diffusers modules:
from diffusers import StableDiffusionPipeline, DDIMScheduler
I see the following error:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
[<ipython-input-5-5f3e3eccac57>](https://localhost:8080/#) in <cell line: 0>()
----> 1 from diffusers import StableDiffusionPipeline, DDIMScheduler
12 frames
[/usr/local/lib/python3.11/dist-packages/diffusers/__init__.py](https://localhost:8080/#) in <module>
19 if is_torch_available():
20 from .modeling_utils import ModelMixin
---> 21 from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
22 from .optimization import (
23 get_constant_schedule,
[/usr/local/lib/python3.11/dist-packages/diffusers/models/__init__.py](https://localhost:8080/#) in <module>
24
25 if is_flax_available():
---> 26 from .unet_2d_condition_flax import FlaxUNet2DConditionModel
27 from .vae_flax import FlaxAutoencoderKL
[/usr/local/lib/python3.11/dist-packages/diffusers/models/unet_2d_condition_flax.py](https://localhost:8080/#) in <module>
14 from typing import Tuple, Union
15
---> 16 import flax
17 import flax.linen as nn
18 import jax
[/usr/local/lib/python3.11/dist-packages/flax/__init__.py](https://localhost:8080/#) in <module>
22 del configurations
23
---> 24 from flax import core
25 from flax import jax_utils
26 from flax import linen
[/usr/local/lib/python3.11/dist-packages/flax/core/__init__.py](https://localhost:8080/#) in <module>
22 unfreeze as unfreeze,
23 )
---> 24 from .lift import (
25 custom_vjp as custom_vjp,
26 jit as jit,
[/usr/local/lib/python3.11/dist-packages/flax/core/lift.py](https://localhost:8080/#) in <module>
25
26 from flax import traceback_util
---> 27 from flax import traverse_util
28 from flax.typing import (
29 In,
[/usr/local/lib/python3.11/dist-packages/flax/traverse_util.py](https://localhost:8080/#) in <module>
64
65 import flax
---> 66 from flax.core.scope import VariableDict
67 from flax.typing import PathParts
68
[/usr/local/lib/python3.11/dist-packages/flax/core/scope.py](https://localhost:8080/#) in <module>
53 )
54
---> 55 from . import meta, partial_eval, tracers
56 from .frozen_dict import FrozenDict, freeze, unfreeze
57
[/usr/local/lib/python3.11/dist-packages/flax/core/meta.py](https://localhost:8080/#) in <module>
186
187
--> 188 class Partitioned(struct.PyTreeNode, AxisMetadata[A]):
189 """Wrapper for partitioning metadata.
190
/usr/lib/python3.11/abc.py in __new__(mcls, name, bases, namespace, **kwargs)
[/usr/local/lib/python3.11/dist-packages/flax/struct.py](https://localhost:8080/#) in __init_subclass__(cls, **kwargs)
233
234 def __init_subclass__(cls, **kwargs):
--> 235 dataclass(cls, **kwargs) # pytype: disable=wrong-arg-types
236
237 def __init__(self, *args, **kwargs):
[/usr/local/lib/python3.11/dist-packages/flax/struct.py](https://localhost:8080/#) in dataclass(clz, **kwargs)
148 data_clz.replace = replace
149
--> 150 jax.tree_util.register_dataclass(data_clz, data_fields, meta_fields)
151
152 def to_state_dict(x):
[/usr/local/lib/python3.11/dist-packages/jax/_src/deprecations.py](https://localhost:8080/#) in getattr(name)
51 warnings.warn(message, DeprecationWarning, stacklevel=2)
52 return fn
---> 53 raise AttributeError(f"module {module!r} has no attribute {name!r}")
54
55 return getattr
AttributeError: module 'jax.tree_util' has no attribute 'register_dataclass'
I tried upgrading jax and jaxlib to version 0.5.3, but then I see this error instead:
---------------------------------------------------------------------------
ImportError Traceback (most recent call last)
[<ipython-input-2-5f3e3eccac57>](https://localhost:8080/#) in <cell line: 0>()
----> 1 from diffusers import StableDiffusionPipeline, DDIMScheduler
2 frames
[/usr/local/lib/python3.11/dist-packages/diffusers/__init__.py](https://localhost:8080/#) in <module>
29 get_scheduler,
30 )
---> 31 from .pipeline_utils import DiffusionPipeline
32 from .pipelines import (
33 DanceDiffusionPipeline,
[/usr/local/lib/python3.11/dist-packages/diffusers/pipeline_utils.py](https://localhost:8080/#) in <module>
33
34 from .configuration_utils import ConfigMixin
---> 35 from .dynamic_modules_utils import get_class_from_dynamic_module
36 from .hub_utils import http_user_agent
37 from .modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
[/usr/local/lib/python3.11/dist-packages/diffusers/dynamic_modules_utils.py](https://localhost:8080/#) in <module>
24 from typing import Dict, Optional, Union
25
---> 26 from huggingface_hub import HfFolder, cached_download, hf_hub_download, model_info
27
28 from .utils import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging
ImportError: cannot import name 'cached_download' from 'huggingface_hub' (/usr/local/lib/python3.11/dist-packages/huggingface_hub/__init__.py)
I also tried upgrading both jax and diffusers (diffusers to 0.32.2), but then I encounter the following error in the last cells:
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
[<ipython-input-12-1772ee7a7d70>](https://localhost:8080/#) in <cell line: 0>()
----> 1 embedding = optimize_embedding(
2 ldm,
3 num_steps=num_optimization_steps,
4 batch_size=batch_size,
5 top_k = num_keypoints,
10 frames
[<ipython-input-11-aead50517aea>](https://localhost:8080/#) in optimize_embedding(ldm, context, device, num_steps, from_where, upsample_res, layers, lr, noise_level, num_tokens, top_k, augment_degrees, augment_scale, augment_translate, dataset_loc, sigma, sharpening_loss_weight, equivariance_attn_loss_weight, batch_size, num_gpus, max_len, min_dist, furthest_point_num_samples, controllers, validation, num_subjects)
76 image = mini_batch["img"]
77
---> 78 attn_maps = run_and_find_attn(
79 ldm,
80 image,
[<ipython-input-10-d782888e078d>](https://localhost:8080/#) in run_and_find_attn(ldm, image, context, noise_level, device, from_where, layers, upsample_res, indices, controllers)
113 controllers=None,
114 ):
--> 115 _, _ = find_pred_noise(
116 ldm,
117 image,
[<ipython-input-10-d782888e078d>](https://localhost:8080/#) in find_pred_noise(ldm, image, context, noise_level, device)
41 # import ipdb; ipdb.set_trace()
42
---> 43 pred_noise = ldm.unet(noisy_image,
44 ldm.scheduler.timesteps[noise_level].repeat(noisy_image.shape[0]),
45 context.repeat(noisy_image.shape[0], 1, 1))["sample"]
[/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs)
1737 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1738 else:
-> 1739 return self._call_impl(*args, **kwargs)
1740
1741 # torchrec tests the code consistency with the following code
[/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _call_impl(self, *args, **kwargs)
1748 or _global_backward_pre_hooks or _global_backward_hooks
1749 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750 return forward_call(*args, **kwargs)
1751
1752 result = None
[/usr/local/lib/python3.11/dist-packages/torch/nn/parallel/data_parallel.py](https://localhost:8080/#) in forward(self, *inputs, **kwargs)
189
190 if len(self.device_ids) == 1:
--> 191 return self.module(*inputs[0], **module_kwargs[0])
192 replicas = self.replicate(self.module, self.device_ids[: len(inputs)])
193 outputs = self.parallel_apply(replicas, inputs, module_kwargs)
[/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs)
1737 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1738 else:
-> 1739 return self._call_impl(*args, **kwargs)
1740
1741 # torchrec tests the code consistency with the following code
[/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _call_impl(self, *args, **kwargs)
1843
1844 try:
-> 1845 return inner()
1846 except Exception:
1847 # run always called hooks if they have not already been run
[/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in inner()
1780 )
1781 else:
-> 1782 args_result = hook(self, args)
1783 if args_result is not None:
1784 if not isinstance(args_result, tuple):
[<ipython-input-6-72ab3767eb92>](https://localhost:8080/#) in hook_fn(module, input)
200 _device = input[0].device
201 # if device not in patched_devices:
--> 202 register_attention_control(module, controllers[_device], feature_upsample_res=feature_upsample_res)
203 # patched_devices.add(device)
204
[<ipython-input-6-72ab3767eb92>](https://localhost:8080/#) in register_attention_control(model, controller, feature_upsample_res)
160
161 # create assertion with message
--> 162 assert cross_att_count != 0, "No cross attention layers found in the model. Please check to make sure you're using diffusers==0.8.0."
163
164 def load_ldm(device, type="CompVis/stable-diffusion-v1-4", feature_upsample_res=256):
AssertionError: No cross attention layers found in the model. Please check to make sure you're using diffusers==0.8.0.
Could you please recommend which version of each package to use? Thanks!
Metadata
Metadata
Assignees
Labels
No labels