support for encoder from gemma-3-4b-it #42
Conversation
llava/model/llava_arch.py
Outdated
# ====================================================================================
# ================== FIX 2: Brute-force clip the features ============================
image_features = torch.clamp(image_features, min=-10.0, max=10.0)
we need to do this only for the SigLIP tower from Gemma 3, to make sure other encoders are unchanged?
yes, i will add a check to do it only for vision towers from Gemma
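A minimal sketch of what that gated clamp could look like (the `vision_tower_name` parameter, the helper name, and the `[-10, 10]` bounds taken from the diff above are illustrative assumptions, not the final implementation):

```python
import torch


def clamp_if_gemma(image_features: torch.Tensor, vision_tower_name: str) -> torch.Tensor:
    """Clip feature magnitudes only for vision towers from Gemma.

    Features from other encoders pass through unchanged, as requested
    in the review. Name and bounds are illustrative.
    """
    if "gemma" in vision_tower_name.lower():
        # Brute-force clip the extreme magnitudes to a stable range
        # before the features reach the language model.
        return torch.clamp(image_features, min=-10.0, max=10.0)
    return image_features
```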
class LLaVATrainer(Trainer):

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
for reference, this is how compute_loss looks in the HF Trainer class: https://github.com/huggingface/transformers/blob/052e652d6d53c2b26ffde87e039b723949a53493/src/transformers/trainer.py#L3618
class LLaVATrainer(Trainer):

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
can we also make sure this compute_loss is bypassed for other encoders? I am still not sure why exactly we need a custom loss here. Is it to handle a sequence-mismatch problem? If so, why is it not a problem for other encoders? Or is there another reason? Unless those are clear, let's make sure we do this custom compute_loss only for gemma3-siglip
in your `if 'siglip' in self.data_args.image_processor.image_processor_type.lower():` line, you need to update it to also cover Gemma, e.g.
`proc_type = self.data_args.image_processor.image_processor_type.lower()`
`if 'siglip' in proc_type or 'gemma' in proc_type:`
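A small self-contained helper for that check might look like this (the function name is an assumption; note that in Python, `'siglip' or 'gemma' in s` parses as `'siglip' or ('gemma' in s)`, which is always truthy, so each substring needs its own `in` test):

```python
def is_gemma_siglip(image_processor_type: str) -> bool:
    """True when the image processor comes from a SigLIP/Gemma vision tower.

    Each substring gets its own `in` test: writing
    `'siglip' or 'gemma' in proc` would always evaluate truthy,
    because the bare string 'siglip' is itself truthy.
    """
    proc = image_processor_type.lower()
    return "siglip" in proc or "gemma" in proc
```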
This PR adds support for using the gemma3-siglip-encoder (from `google/gemma-3-4b-it`) as a vision tower for LLaVA pre-training with a Vicuna-based LLM.

1. Numerical Instability Issue - NaN Loss

Initial attempts to pre-train using the gemma3-siglip-encoder resulted in a persistent `NaN` loss. Debugging revealed that the encoder produces feature outputs with an extremely large numerical magnitude. This triggered a low-level bug deep inside the language model's `CrossEntropyLoss` function, causing it to fail even when all inputs (`logits` and `labels`) were valid.

2. Implementation Details

To enable stable training, the following two-part solution was implemented:

Feature Clipping: A `torch.clamp` call was added to the `encode_images` method in `llava/model/llava_arch.py`. This controls the extreme magnitude of the Gemma features by ensuring they are within a stable `[-10, 10]` range before being passed to the language model.

Manual Loss Calculation: The `compute_loss` method in `llava/train/llava_trainer.py` was overridden to bypass the model's unstable internal loss function. This implementation takes the clean `logits` from the model and performs a stable, manual `CrossEntropyLoss` calculation.
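A minimal sketch of such a manual calculation for a causal LM (the function name and the fp32 upcast are illustrative assumptions; the actual override lives in the PR's `llava/train/llava_trainer.py`):

```python
import torch
import torch.nn.functional as F


def manual_shifted_ce_loss(
    logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100
) -> torch.Tensor:
    """Manual causal-LM cross-entropy: position t predicts token t+1.

    Shapes: logits (batch, seq, vocab), labels (batch, seq).
    Upcasting to fp32 before the loss is a common stability measure;
    whether the PR does this exactly is an assumption.
    """
    # Drop the last logit and the first label so predictions align
    # with the next token at each position.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)).float(),
        shift_labels.view(-1),
        ignore_index=ignore_index,  # masked positions don't contribute
    )
```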