5 changes: 5 additions & 0 deletions README.md
@@ -47,7 +47,12 @@ pip install -e .
pip install -e ".[train]"
pip install flash-attn==2.6.3 --no-build-isolation --no-cache-dir
```
4. Bump the transformers version for Siglip2-Naflex (4.50.0) and Aimv2 (4.54.0). This also requires bumping accelerate to 1.3.0, since accelerate's `unwrap_model` function introduced the `keep_torch_compile` parameter that the newer transformers expects.

```
pip install -U transformers==4.54.0
pip install -U accelerate==1.3.0
```
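
A quick post-upgrade sanity check can be useful; this is a minimal sketch (not part of the repository) that assumes the pins above and that `Accelerator.unwrap_model` now exposes `keep_torch_compile`:

```python
# Hypothetical post-install check: print the pinned versions and confirm that
# accelerate's Accelerator.unwrap_model accepts keep_torch_compile, which
# newer transformers releases pass through.
import inspect

import accelerate
import transformers
from accelerate import Accelerator

print("transformers:", transformers.__version__)  # expected: 4.54.0
print("accelerate:", accelerate.__version__)      # expected: 1.3.0
assert "keep_torch_compile" in inspect.signature(Accelerator.unwrap_model).parameters
```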

## Model Weights and Dataset
[HuggingFace](https://huggingface.co/maya-multimodal)
9,148 changes: 9,148 additions & 0 deletions llava-siglip2-naflex.log

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions llava/model/multimodal_encoder/builder.py
@@ -1,6 +1,8 @@
import os
from transformers.trainer import logger
from .clip_encoder import CLIPVisionTower, CLIPVisionTowerS2
from .siglip_encoder import SiglipVisionTower
from .siglip2_encoder import Siglip2VisionTower

def build_vision_tower(vision_tower_cfg, **kwargs):
vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
@@ -11,6 +13,10 @@ def build_vision_tower(vision_tower_cfg, **kwargs):
return CLIPVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs)
else:
return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
elif 'naflex' in vision_tower:
logger.info(f"naflex version with {vision_tower} found")
return Siglip2VisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
elif 'siglip' in vision_tower or 'gemma' in vision_tower:
logger.info(f"siglip was in {vision_tower} found")
return SiglipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
raise ValueError(f'Unknown vision tower: {vision_tower}')
177 changes: 177 additions & 0 deletions llava/model/multimodal_encoder/siglip2_encoder.py
@@ -0,0 +1,177 @@
'''
This file only handles the naflex variant of Siglip2. The regular Siglip2 is handled the same way as Siglip.
'''

import torch
import torch.nn as nn

from transformers import Siglip2VisionModel, Siglip2ImageProcessor, Siglip2VisionConfig

class Siglip2VisionTower(nn.Module):
def __init__(self, vision_tower, args, delay_load=False):
super().__init__()

self.is_loaded = False

self.vision_tower_name = vision_tower
self.select_layer = args.mm_vision_select_layer
self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')

if not delay_load:
self.load_model()
elif getattr(args, 'unfreeze_mm_vision_tower', False):
self.load_model()
else:
self.cfg_only = Siglip2VisionConfig.from_pretrained(self.vision_tower_name)

def load_model(self, device_map=None):
if self.is_loaded:
print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
return

self.image_processor = Siglip2ImageProcessor.from_pretrained(self.vision_tower_name)
self.vision_tower = Siglip2VisionModel.from_pretrained(self.vision_tower_name, device_map=device_map, ignore_mismatched_sizes=True)
self.vision_tower.requires_grad_(False)

self.is_loaded = True

def feature_select(self, image_forward_outs):
image_features = image_forward_outs.hidden_states[self.select_layer]
if self.select_feature == 'patch':
image_features = image_features[:, 1:]
elif self.select_feature == 'cls_patch':
image_features = image_features
else:
raise ValueError(f'Unexpected select feature: {self.select_feature}')
return image_features

@torch.no_grad()
def forward(self, images):
if type(images) is list:
image_features = []
for image in images:
# Ensure input is float32 before image processor
# images = images.to(dtype=torch.float32)
# Preprocess image to get pixel_values, pixel_attention_mask, spatial_shapes
batch_feature = self.image_processor(image.unsqueeze(0))
pixel_values = batch_feature["pixel_values"].to(device=self.device, dtype=self.dtype)
pixel_attention_mask = batch_feature["pixel_attention_mask"].to(device=self.device)
spatial_shapes = batch_feature["spatial_shapes"].to(device=self.device)

# Call vision model
image_forward_out = self.vision_tower(
pixel_values=pixel_values,
pixel_attention_mask=pixel_attention_mask,
spatial_shapes=spatial_shapes,
output_hidden_states=True,
)

image_feature = self.feature_select(image_forward_out).to(image.dtype)
image_features.append(image_feature)
else:
batch_feature = self.image_processor(images)
pixel_values = batch_feature["pixel_values"].to(device=self.device, dtype=self.dtype)
pixel_attention_mask = batch_feature["pixel_attention_mask"].to(device=self.device)
spatial_shapes = batch_feature["spatial_shapes"].to(device=self.device)

image_forward_outs = self.vision_tower(
pixel_values=pixel_values,
pixel_attention_mask=pixel_attention_mask,
spatial_shapes=spatial_shapes,
output_hidden_states=True,
)

image_features = self.feature_select(image_forward_outs).to(images.dtype)

return image_features

@property
def dummy_feature(self):
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

@property
def dtype(self):
return self.vision_tower.dtype

@property
def device(self):
return self.vision_tower.device

@property
def config(self):
if self.is_loaded:
return self.vision_tower.config
else:
return self.cfg_only

@property
def hidden_size(self):
return self.config.hidden_size

@property
def num_patches_per_side(self):
return self.config.image_size // self.config.patch_size

@property
def num_patches(self):
return (self.config.image_size // self.config.patch_size) ** 2



class Siglip2VisionTowerS2(Siglip2VisionTower):
def __init__(self, vision_tower, args, delay_load=False):
super().__init__(vision_tower, args, delay_load)

self.s2_scales = getattr(args, 's2_scales', '336,672,1008')
self.s2_scales = list(map(int, self.s2_scales.split(',')))
self.s2_scales.sort()
self.s2_split_size = self.s2_scales[0]
self.s2_image_size = self.s2_scales[-1]

try:
from s2wrapper import forward as multiscale_forward
except ImportError:
raise ImportError('Package s2wrapper not found! Please install by running: \npip install git+https://github.com/bfshi/scaling_on_scales.git')
self.multiscale_forward = multiscale_forward

# change resize/crop size in preprocessing to the largest image size in s2_scale
if not delay_load or getattr(args, 'unfreeze_mm_vision_tower', False):
self.image_processor.size['shortest_edge'] = self.s2_image_size
self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size

def load_model(self, device_map=None):
if self.is_loaded:
print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
return

self.image_processor = Siglip2ImageProcessor.from_pretrained(self.vision_tower_name)
self.vision_tower = Siglip2VisionModel.from_pretrained(self.vision_tower_name, device_map=device_map, ignore_mismatched_sizes=True)
self.vision_tower.requires_grad_(False)

self.image_processor.size['shortest_edge'] = self.s2_image_size
self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size

self.is_loaded = True

@torch.no_grad()
def forward_feature(self, images):
image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
image_features = self.feature_select(image_forward_outs).to(images.dtype)
return image_features

@torch.no_grad()
def forward(self, images):
if type(images) is list:
image_features = []
for image in images:
image_feature = self.multiscale_forward(self.forward_feature, image.unsqueeze(0), img_sizes=self.s2_scales, max_split_size=self.s2_split_size)
image_features.append(image_feature)
else:
image_features = self.multiscale_forward(self.forward_feature, images, img_sizes=self.s2_scales, max_split_size=self.s2_split_size)

return image_features

@property
def hidden_size(self):
return self.config.hidden_size * len(self.s2_scales)
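
A minimal usage sketch (not part of the diff) of the new tower, assuming the checkpoint is reachable on the Hub and that images reach `Siglip2VisionTower.forward` as CHW tensors; the config attribute names follow the LLaVA conventions used above:

```python
# Sketch only: exercise the 'naflex' branch of build_vision_tower with a dummy image.
from types import SimpleNamespace

import torch

from llava.model.multimodal_encoder.builder import build_vision_tower

cfg = SimpleNamespace(
    mm_vision_tower="google/siglip2-so400m-patch16-naflex",
    mm_vision_select_layer=-2,
    mm_vision_select_feature="patch",
)
tower = build_vision_tower(cfg)          # dispatches to Siglip2VisionTower via the 'naflex' branch
image = torch.rand(3, 384, 512) * 255    # dummy RGB image, CHW, pixel values in [0, 255]
features = tower([image])                # list input -> one feature tensor per image
print(features[0].shape)                 # (1, sequence_length, hidden_size)
```
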

4 changes: 2 additions & 2 deletions llava/model/multimodal_encoder/siglip_encoder.py
@@ -26,7 +26,7 @@ def load_model(self, device_map=None):
return

self.image_processor = SiglipImageProcessor.from_pretrained(self.vision_tower_name)
self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map, ignore_mismatched_sizes=True)
self.vision_tower.requires_grad_(False)

self.is_loaded = True
@@ -115,7 +115,7 @@ def load_model(self, device_map=None):
return

self.image_processor = SiglipImageProcessor.from_pretrained(self.vision_tower_name)
self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map, ignore_mismatched_sizes=True)
self.vision_tower.requires_grad_(False)

self.image_processor.size['shortest_edge'] = self.s2_image_size
4 changes: 2 additions & 2 deletions llava/train/llava_trainer.py
@@ -9,9 +9,9 @@
is_sagemaker_mp_enabled,
get_parameter_names,
has_length,
ALL_LAYERNORM_LAYERS,
logger,
)
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from typing import List, Optional


@@ -132,7 +132,7 @@ def __iter__(self):

class LLaVATrainer(Trainer):

def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
def _get_train_sampler(self, dataset) -> Optional[torch.utils.data.Sampler]:
if self.train_dataset is None or not has_length(self.train_dataset):
return None

5 changes: 5 additions & 0 deletions llava/train/train.py
@@ -781,6 +781,7 @@ def expand2square(pil_img, background_color):
elif self.data_args.is_multimodal:
# image does not exist in the data, but the model is multimodal
#crop_size = self.data_args.image_processor.crop_size or {'height': 224, 'width': 224}
# TODO: How should siglip naflex be handled here for crop size? Just keep it the same as the other siglip variants? Since it uses variable-length patch sequences, does it matter?
if 'siglip' in self.data_args.image_processor.image_processor_type.lower():
crop_size = {'height': 256, 'width': 256}
else:
@@ -876,6 +877,8 @@ def train(attn_implementation=None):
elif 'aya' in model_args.model_name_or_path:
model = LlavaCohereForCausalLM.from_pretrained(
model_args.model_name_or_path,
trust_remote_code=True,
use_safetensors=True,
cache_dir=training_args.cache_dir,
attn_implementation=attn_implementation,
torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
@@ -884,6 +887,8 @@
else:
model = LlavaLlamaForCausalLM.from_pretrained(
model_args.model_name_or_path,
trust_remote_code=True,
use_safetensors=True,
cache_dir=training_args.cache_dir,
attn_implementation=attn_implementation,
torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
2 changes: 1 addition & 1 deletion scripts/v1_5/pretrain_llava_siglip2.sh
@@ -4,7 +4,7 @@ deepspeed llava/train/train_mem.py \
--deepspeed ./scripts/zero2.json \
--model_name_or_path lmsys/vicuna-7b-v1.5 \
--version plain \
--data_path ./dev/data/LLaVA-Pretrain/blip_laion_cc_sbu_558k.json \
--data_path ./dev/data/LLaVA_Pretrain/blip_laion_cc_sbu_558k.json \
--image_folder ./dev/data/images \
--vision_tower google/siglip2-base-patch16-256 \
--mm_projector_type mlp2x_gelu \
6 changes: 3 additions & 3 deletions scripts/v1_5/pretrain_llava_siglip2_naflex.sh
@@ -4,8 +4,8 @@ deepspeed llava/train/train_mem.py \
--deepspeed ./scripts/zero2.json \
--model_name_or_path lmsys/vicuna-7b-v1.5 \
--version plain \
--data_path ./dev/data/LLaVA-Pretrain/blip_laion_cc_sbu_558k.json \
--image_folder ./dev/data/images \
--data_path /dev/data/LLaVA_Pretrain/blip_laion_cc_sbu_558k.json \
--image_folder /dev/data/images \
--vision_tower google/siglip2-so400m-patch16-naflex \
--mm_projector_type mlp2x_gelu \
--tune_mm_mlp_adapter True \
@@ -18,7 +18,7 @@ deepspeed llava/train/train_mem.py \
--per_device_train_batch_size 32 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 1 \
--evaluation_strategy "no" \
--eval_strategy "no" \
--save_strategy "steps" \
--save_steps 24000 \
--save_total_limit 1 \