@MULTIMODAL_REGISTRY.register_processor(
    MusicFlamingoMultiModalProcessor,
    info=MusicFlamingoProcessingInfo,
    dummy_inputs=MusicFlamingoDummyInputsBuilder,
)
class MusicFlamingoForConditionalGeneration(AudioFlamingo3ForConditionalGeneration):
    """vLLM MusicFlamingo model aligned with HF modular_musicflamingo.

    Builds on the AudioFlamingo3 implementation: after the parent
    constructor runs, this class assigns its own audio encoder, multimodal
    projector and rotary embedding module, and it overrides
    ``_process_audio_input`` so that a time-based rotary embedding is applied
    to the encoder output before projection.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Initialise the parent model, then attach MusicFlamingo submodules.

        Args:
            vllm_config: Engine configuration, forwarded to the parent class.
            prefix: Name prefix, forwarded to the parent class.
        """
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        # NOTE(review): `self.config` is presumably populated by the parent
        # constructor -- confirm against AudioFlamingo3ForConditionalGeneration.
        model_config = self.config
        self.audio_tower = MusicFlamingoEncoder(model_config.audio_config)
        self.multi_modal_projector = MusicFlamingoMultiModalProjector(model_config)
        self.pos_emb = MusicFlamingoRotaryEmbedding(model_config)

    def _process_audio_input(
        self, audio_input: MusicFlamingoInputs
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Convert a parsed audio input into audio embeddings.

        Inputs tagged ``"audio_embeds"`` are delegated unchanged to the parent
        implementation. Any other input goes through the feature path:
        normalise -> encode -> rotary time embedding -> project -> group.

        Args:
            audio_input: Parsed audio input; its ``"type"`` entry selects the
                processing path.

        Returns:
            Whatever the parent implementation returns for pre-computed
            embeddings, otherwise the result of ``_group_audio_embeddings``.
        """
        # Pre-computed embeddings need no MusicFlamingo-specific handling.
        if audio_input["type"] == "audio_embeds":
            return super()._process_audio_input(audio_input)

        normalized = self._normalize_audio_feature_inputs(audio_input)
        features, feature_mask, chunks_per_audio = normalized

        encoded = self._encode_audio_features(features, feature_mask)

        # The second-to-last axis of the encoder output is used as the
        # sequence length for both the timestamps and the rotary tables.
        num_frames = encoded.shape[-2]
        timestamps = _build_audio_timestamps(
            feature_mask,
            chunks_per_audio,
            num_frames,
            self.config.audio_frame_step,
        )

        # Move the timestamps to the encoder output's device before the
        # rotary module consumes them.
        cos, sin = self.pos_emb(timestamps.to(encoded.device), seq_len=num_frames)
        rotated = apply_rotary_time_emb(encoded, cos, sin)

        projected = self.multi_modal_projector(rotated)
        return self._group_audio_embeddings(projected, feature_mask, chunks_per_audio)