Skip to content

vllm.model_executor.models.musicflamingo

MusicFlamingoForConditionalGeneration

Bases: AudioFlamingo3ForConditionalGeneration

vLLM implementation of the MusicFlamingo model, aligned with Hugging Face's `modular_musicflamingo`.

Source code in vllm/model_executor/models/musicflamingo.py
@MULTIMODAL_REGISTRY.register_processor(
    MusicFlamingoMultiModalProcessor,
    info=MusicFlamingoProcessingInfo,
    dummy_inputs=MusicFlamingoDummyInputsBuilder,
)
class MusicFlamingoForConditionalGeneration(AudioFlamingo3ForConditionalGeneration):
    """MusicFlamingo model for vLLM, aligned with HF ``modular_musicflamingo``.

    Builds on ``AudioFlamingo3ForConditionalGeneration`` and installs the
    MusicFlamingo audio encoder, multimodal projector and a rotary time
    embedding. For raw audio features, the rotary time embedding is applied
    to the encoder output before it is projected.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        # NOTE(review): the parent __init__ presumably already constructs an
        # ``audio_tower`` / ``multi_modal_projector``; the assignments below
        # replace them with the MusicFlamingo variants. Confirm against the
        # parent whether the duplicate construction is intended.
        cfg = self.config
        self.audio_tower = MusicFlamingoEncoder(cfg.audio_config)
        self.multi_modal_projector = MusicFlamingoMultiModalProjector(cfg)
        self.pos_emb = MusicFlamingoRotaryEmbedding(cfg)

    def _process_audio_input(
        self, audio_input: MusicFlamingoInputs
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Convert a parsed audio input into language-model embeddings.

        Inputs whose ``"type"`` is ``"audio_embeds"`` are delegated unchanged
        to the parent implementation. Any other input goes through this path:

        1. The features are normalized.
        2. The audio tower encodes them.
        3. A rotary time embedding, built from per-frame audio timestamps, is
           applied.
        4. The result is projected and regrouped by
           ``_group_audio_embeddings``.

        Args:
            audio_input: Parsed multimodal audio input. Its ``"type"`` key
                selects between the embeds path and the raw-feature path.

        Returns:
            The value of ``_group_audio_embeddings`` (or, for the embeds
            path, of the parent implementation): a tensor or tuple of tensors.
        """
        if audio_input["type"] == "audio_embeds":
            return super()._process_audio_input(audio_input)

        features, attention_mask, counts = self._normalize_audio_feature_inputs(
            audio_input
        )
        hidden = self._encode_audio_features(features, attention_mask)

        # The second-to-last dimension of the encoder output is what the
        # rotary embedding receives as its sequence length.
        seq_len = hidden.shape[-2]
        timestamps = _build_audio_timestamps(
            attention_mask,
            counts,
            seq_len,
            self.config.audio_frame_step,
        ).to(hidden.device)

        cos, sin = self.pos_emb(timestamps, seq_len=seq_len)
        # Rebind ``hidden`` so the un-rotated encoder output can be released.
        hidden = apply_rotary_time_emb(hidden, cos, sin)
        projected = self.multi_modal_projector(hidden)

        return self._group_audio_embeddings(projected, attention_mask, counts)