From 0016067709fc6c8bf49f78d425c57229bddde4dd Mon Sep 17 00:00:00 2001 From: dengcunqin Date: Sat, 17 Jan 2026 18:12:52 +0800 Subject: [PATCH 1/2] paraformer_v2_community offline This PR provides a code reproduction of Paraformer-v2: An Improved Non-Autoregressive Transformer for Noise-Robust Speech Recognition, as described in the paper. The implementation may not be fully consistent with the official code but aims to replicate the core concepts and methods presented in the paper. --- .../paraformer_v2_community/__init__.py | 0 .../models/paraformer_v2_community/decoder.py | 590 ++++++++++++++++++ .../models/paraformer_v2_community/model.py | 547 ++++++++++++++++ .../paraformer_v2_community/template.yaml | 112 ++++ 4 files changed, 1249 insertions(+) create mode 100644 funasr/models/paraformer_v2_community/__init__.py create mode 100644 funasr/models/paraformer_v2_community/decoder.py create mode 100644 funasr/models/paraformer_v2_community/model.py create mode 100644 funasr/models/paraformer_v2_community/template.yaml diff --git a/funasr/models/paraformer_v2_community/__init__.py b/funasr/models/paraformer_v2_community/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/funasr/models/paraformer_v2_community/decoder.py b/funasr/models/paraformer_v2_community/decoder.py new file mode 100644 index 000000000..556873db5 --- /dev/null +++ b/funasr/models/paraformer_v2_community/decoder.py @@ -0,0 +1,590 @@ +#!/usr/bin/env python3 +# -*- encoding: utf-8 -*- +# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. +# MIT License (https://opensource.org/licenses/MIT) + +import torch +from typing import List, Tuple + +from funasr.register import tables +from funasr.models.scama import utils as myutils +from funasr.models.transformer.utils.repeat import repeat +from funasr.models.transformer.decoder import DecoderLayer +from funasr.models.transformer.layer_norm import LayerNorm +from funasr.models.transformer.embedding import PositionalEncoding +from funasr.models.transformer.attention import MultiHeadedAttention +from funasr.models.transformer.utils.nets_utils import make_pad_mask +from funasr.models.transformer.decoder import BaseTransformerDecoder +from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward +from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM +from funasr.models.sanm.attention import ( + MultiHeadedAttentionSANMDecoder, + MultiHeadedAttentionCrossAtt, +) + + +class DecoderLayerSANM(torch.nn.Module): + """Single decoder layer module. + + Args: + size (int): Input dimension. + self_attn (torch.nn.Module): Self-attention module instance. + `MultiHeadedAttention` instance can be used as the argument. + src_attn (torch.nn.Module): Self-attention module instance. + `MultiHeadedAttention` instance can be used as the argument. + feed_forward (torch.nn.Module): Feed-forward module instance. + `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance + can be used as the argument. + dropout_rate (float): Dropout rate. + normalize_before (bool): Whether to use layer_norm before the first block. + concat_after (bool): Whether to concat attention layer's input and output. + if True, additional linear will be applied. + i.e. x -> x + linear(concat(x, att(x))) + if False, no additional linear will be applied. i.e. 
x -> x + att(x) + + + """ + + def __init__( + self, + size, + self_attn, + src_attn, + feed_forward, + dropout_rate, + normalize_before=True, + concat_after=False, + ): + """Construct an DecoderLayer object.""" + super(DecoderLayerSANM, self).__init__() + self.size = size + self.self_attn = self_attn + self.src_attn = src_attn + self.feed_forward = feed_forward + self.norm1 = LayerNorm(size) + if self_attn is not None: + self.norm2 = LayerNorm(size) + if src_attn is not None: + self.norm3 = LayerNorm(size) + self.dropout = torch.nn.Dropout(dropout_rate) + self.normalize_before = normalize_before + self.concat_after = concat_after + if self.concat_after: + self.concat_linear1 = torch.nn.Linear(size + size, size) + self.concat_linear2 = torch.nn.Linear(size + size, size) + self.reserve_attn = False + self.attn_mat = [] + + def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None): + """Compute decoded features. + + Args: + tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size). + tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out). + memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size). + memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in). + cache (List[torch.Tensor]): List of cached tensors. + Each tensor shape should be (#batch, maxlen_out - 1, size). + + Returns: + torch.Tensor: Output tensor(#batch, maxlen_out, size). + torch.Tensor: Mask for output tensor (#batch, maxlen_out). + torch.Tensor: Encoded memory (#batch, maxlen_in, size). + torch.Tensor: Encoded memory mask (#batch, maxlen_in). + + """ + # tgt = self.dropout(tgt) + residual = tgt + if self.normalize_before: + tgt = self.norm1(tgt) + tgt = self.feed_forward(tgt) + + x = tgt + if self.self_attn: + if self.normalize_before: + tgt = self.norm2(tgt) + x, _ = self.self_attn(tgt, tgt_mask) + x = residual + self.dropout(x) + + if self.src_attn is not None: + residual = x + if self.normalize_before: + x = self.norm3(x) + if self.reserve_attn: + x_src_attn, attn_mat = self.src_attn(x, memory, memory_mask, ret_attn=True) + self.attn_mat.append(attn_mat) + else: + x_src_attn = self.src_attn(x, memory, memory_mask, ret_attn=False) + x = residual + self.dropout(x_src_attn) + # x = residual + self.dropout(self.src_attn(x, memory, memory_mask)) + + return x, tgt_mask, memory, memory_mask, cache + + def get_attn_mat(self, tgt, tgt_mask, memory, memory_mask=None, cache=None): + residual = tgt + tgt = self.norm1(tgt) + tgt = self.feed_forward(tgt) + + x = tgt + if self.self_attn is not None: + tgt = self.norm2(tgt) + x, cache = self.self_attn(tgt, tgt_mask, cache=cache) + x = residual + x + + residual = x + x = self.norm3(x) + x_src_attn, attn_mat = self.src_attn(x, memory, memory_mask, ret_attn=True) + return attn_mat + + def forward_one_step(self, tgt, tgt_mask, memory, memory_mask=None, cache=None): + """Compute decoded features. + + Args: + tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size). + tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out). + memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size). + memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in). + cache (List[torch.Tensor]): List of cached tensors. + Each tensor shape should be (#batch, maxlen_out - 1, size). + + Returns: + torch.Tensor: Output tensor(#batch, maxlen_out, size). + torch.Tensor: Mask for output tensor (#batch, maxlen_out). + torch.Tensor: Encoded memory (#batch, maxlen_in, size). + torch.Tensor: Encoded memory mask (#batch, maxlen_in). 
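# --- Illustrative note (not part of the patch) ---------------------------------
# A minimal, standalone sketch of the mask shapes the decoder layers above expect:
# tgt_mask is (B, U, 1) and memory_mask is (B, 1, T), so that broadcasting them
# yields a (B, U, T) cross-attention mask. The sequence_mask helper below is a
# simplified stand-in for funasr.models.scama.utils.sequence_mask.
import torch

def sequence_mask(lengths, maxlen=None):
    maxlen = maxlen or int(lengths.max())
    return (torch.arange(maxlen)[None, :] < lengths[:, None]).float()

ys_in_lens = torch.tensor([3, 5])                    # target lengths per utterance
hlens = torch.tensor([7, 9])                         # encoder output lengths
tgt_mask = sequence_mask(ys_in_lens)[:, :, None]     # (2, 5, 1)
memory_mask = sequence_mask(hlens)[:, None, :]       # (2, 1, 9)
print((tgt_mask * memory_mask).shape)                # torch.Size([2, 5, 9])
# --------------------------------------------------------------------------------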
+ + """ + # tgt = self.dropout(tgt) + residual = tgt + if self.normalize_before: + tgt = self.norm1(tgt) + tgt = self.feed_forward(tgt) + + x = tgt + if self.self_attn: + if self.normalize_before: + tgt = self.norm2(tgt) + if self.training: + cache = None + x, cache = self.self_attn(tgt, tgt_mask, cache=cache) + x = residual + self.dropout(x) + + if self.src_attn is not None: + residual = x + if self.normalize_before: + x = self.norm3(x) + + x = residual + self.dropout(self.src_attn(x, memory, memory_mask)) + + return x, tgt_mask, memory, memory_mask, cache + + def forward_chunk( + self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size=None, look_back=0 + ): + """Compute decoded features. + + Args: + tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size). + tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out). + memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size). + memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in). + cache (List[torch.Tensor]): List of cached tensors. + Each tensor shape should be (#batch, maxlen_out - 1, size). + + Returns: + torch.Tensor: Output tensor(#batch, maxlen_out, size). + torch.Tensor: Mask for output tensor (#batch, maxlen_out). + torch.Tensor: Encoded memory (#batch, maxlen_in, size). + torch.Tensor: Encoded memory mask (#batch, maxlen_in). + + """ + residual = tgt + if self.normalize_before: + tgt = self.norm1(tgt) + tgt = self.feed_forward(tgt) + + x = tgt + if self.self_attn: + if self.normalize_before: + tgt = self.norm2(tgt) + x, fsmn_cache = self.self_attn(tgt, None, fsmn_cache) + x = residual + self.dropout(x) + + if self.src_attn is not None: + residual = x + if self.normalize_before: + x = self.norm3(x) + + x, opt_cache = self.src_attn.forward_chunk(x, memory, opt_cache, chunk_size, look_back) + x = residual + x + + return x, memory, fsmn_cache, opt_cache + + +@tables.register("decoder_classes", "ParaformerSANMDecoder_v2_community") +class ParaformerSANMDecoder(BaseTransformerDecoder): + """ + Author: Speech Lab of DAMO Academy, Alibaba Group + Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition + https://arxiv.org/abs/2006.01713 + """ + + def __init__( + self, + vocab_size: int, + encoder_output_size: int, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + self_attention_dropout_rate: float = 0.0, + src_attention_dropout_rate: float = 0.0, + input_layer: str = "embed", + use_output_layer: bool = True, + wo_input_layer: bool = False, + pos_enc_class=PositionalEncoding, + normalize_before: bool = True, + concat_after: bool = False, + att_layer_num: int = 6, + kernel_size: int = 21, + sanm_shfit: int = 0, + lora_list: List[str] = None, + lora_rank: int = 8, + lora_alpha: int = 16, + lora_dropout: float = 0.1, + chunk_multiply_factor: tuple = (1,), + tf2torch_tensor_name_prefix_torch: str = "decoder", + tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder", + ): + super().__init__( + vocab_size=vocab_size, + encoder_output_size=encoder_output_size, + dropout_rate=dropout_rate, + positional_dropout_rate=positional_dropout_rate, + input_layer=input_layer, + use_output_layer=use_output_layer, + pos_enc_class=pos_enc_class, + normalize_before=normalize_before, + ) + + attention_dim = encoder_output_size + if wo_input_layer: + self.embed = None + else: + if input_layer == "embed": + self.embed = torch.nn.Sequential( + 
torch.nn.Embedding(vocab_size, attention_dim), + # pos_enc_class(attention_dim, positional_dropout_rate), + ) + elif input_layer == "linear": + self.embed = torch.nn.Sequential( + torch.nn.Linear(vocab_size, attention_dim), + torch.nn.LayerNorm(attention_dim), + torch.nn.Dropout(dropout_rate), + torch.nn.ReLU(), + pos_enc_class(attention_dim, positional_dropout_rate), + ) + else: + raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}") + + self.normalize_before = normalize_before + if self.normalize_before: + self.after_norm = LayerNorm(attention_dim) + if use_output_layer: + self.output_layer = torch.nn.Linear(attention_dim, vocab_size) + else: + self.output_layer = None + + self.att_layer_num = att_layer_num + self.num_blocks = num_blocks + if sanm_shfit is None: + sanm_shfit = (kernel_size - 1) // 2 + self.decoders = repeat( + att_layer_num, + lambda lnum: DecoderLayerSANM( + attention_dim, + MultiHeadedAttentionSANMDecoder( + attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit + ), + MultiHeadedAttentionCrossAtt( + attention_heads, + attention_dim, + src_attention_dropout_rate, + lora_list, + lora_rank, + lora_alpha, + lora_dropout, + ), + PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate), + dropout_rate, + normalize_before, + concat_after, + ), + ) + if num_blocks - att_layer_num <= 0: + self.decoders2 = None + else: + self.decoders2 = repeat( + num_blocks - att_layer_num, + lambda lnum: DecoderLayerSANM( + attention_dim, + MultiHeadedAttentionSANMDecoder( + attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=0 + ), + None, + PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate), + dropout_rate, + normalize_before, + concat_after, + ), + ) + + self.decoders3 = repeat( + 1, + lambda lnum: DecoderLayerSANM( + attention_dim, + None, + None, + PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate), + dropout_rate, + normalize_before, + concat_after, + ), + ) + self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch + self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf + self.chunk_multiply_factor = chunk_multiply_factor + + def forward( + self, + hs_pad: torch.Tensor, + hlens: torch.Tensor, + ys_in_pad: torch.Tensor, + ys_in_lens: torch.Tensor, + chunk_mask: torch.Tensor = None, + return_hidden: bool = False, + return_both: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward decoder. 
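# --- Illustrative note (not part of the patch) ---------------------------------
# Why input_layer == "linear" matters here: in Paraformer-v2 the decoder is fed
# compressed CTC posteriors of shape (B, U, vocab_size) instead of token
# embeddings, so the embed module becomes a vocab_size -> attention_dim
# projection. Minimal sketch with toy sizes; the real "linear" branch above also
# appends the positional-encoding module, which is omitted here.
import torch

vocab_size, attention_dim, dropout_rate = 4000, 512, 0.1
embed = torch.nn.Sequential(
    torch.nn.Linear(vocab_size, attention_dim),
    torch.nn.LayerNorm(attention_dim),
    torch.nn.Dropout(dropout_rate),
    torch.nn.ReLU(),
)
ctc_posteriors = torch.softmax(torch.randn(2, 6, vocab_size), dim=-1)  # (B, U, V)
print(embed(ctc_posteriors).shape)                                     # (2, 6, 512)
# --------------------------------------------------------------------------------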
+ + Args: + hs_pad: encoded memory, float32 (batch, maxlen_in, feat) + hlens: (batch) + ys_in_pad: + input token ids, int64 (batch, maxlen_out) + if input_layer == "embed" + input tensor (batch, maxlen_out, #mels) in the other cases + ys_in_lens: (batch) + Returns: + (tuple): tuple containing: + + x: decoded token score before softmax (batch, maxlen_out, token) + if use_output_layer is True, + olens: (batch, ) + """ + tgt = ys_in_pad + tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None] + + memory = hs_pad + memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :] + if chunk_mask is not None: + memory_mask = memory_mask * chunk_mask + if tgt_mask.size(1) != memory_mask.size(1): + memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1) + + x = self.embed(tgt) + x, tgt_mask, memory, memory_mask, _ = self.decoders(x, tgt_mask, memory, memory_mask) + if self.decoders2 is not None: + x, tgt_mask, memory, memory_mask, _ = self.decoders2(x, tgt_mask, memory, memory_mask) + x, tgt_mask, memory, memory_mask, _ = self.decoders3(x, tgt_mask, memory, memory_mask) + if self.normalize_before: + hidden = self.after_norm(x) + + olens = tgt_mask.sum(1) + if self.output_layer is not None and return_hidden is False: + x = self.output_layer(hidden) + return x, olens + if return_both: + x = self.output_layer(hidden) + return x, hidden, olens + return hidden, olens + + def score(self, ys, state, x): + """Score.""" + ys_mask = myutils.sequence_mask( + torch.tensor([len(ys)], dtype=torch.int32), device=x.device + )[:, :, None] + logp, state = self.forward_one_step(ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state) + return logp.squeeze(0), state + + def forward_asf2( + self, + hs_pad: torch.Tensor, + hlens: torch.Tensor, + ys_in_pad: torch.Tensor, + ys_in_lens: torch.Tensor, + ): + + tgt = ys_in_pad + tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None] + + memory = hs_pad + memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :] + + tgt, tgt_mask, memory, memory_mask, _ = self.decoders[0](tgt, tgt_mask, memory, memory_mask) + attn_mat = self.model.decoders[1].get_attn_mat(tgt, tgt_mask, memory, memory_mask) + return attn_mat + + def forward_asf6( + self, + hs_pad: torch.Tensor, + hlens: torch.Tensor, + ys_in_pad: torch.Tensor, + ys_in_lens: torch.Tensor, + ): + + tgt = ys_in_pad + tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None] + + memory = hs_pad + memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :] + + tgt, tgt_mask, memory, memory_mask, _ = self.decoders[0](tgt, tgt_mask, memory, memory_mask) + tgt, tgt_mask, memory, memory_mask, _ = self.decoders[1](tgt, tgt_mask, memory, memory_mask) + tgt, tgt_mask, memory, memory_mask, _ = self.decoders[2](tgt, tgt_mask, memory, memory_mask) + tgt, tgt_mask, memory, memory_mask, _ = self.decoders[3](tgt, tgt_mask, memory, memory_mask) + tgt, tgt_mask, memory, memory_mask, _ = self.decoders[4](tgt, tgt_mask, memory, memory_mask) + attn_mat = self.decoders[5].get_attn_mat(tgt, tgt_mask, memory, memory_mask) + return attn_mat + + def forward_chunk( + self, + memory: torch.Tensor, + tgt: torch.Tensor, + cache: dict = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward decoder. 
+ + Args: + hs_pad: encoded memory, float32 (batch, maxlen_in, feat) + hlens: (batch) + ys_in_pad: + input token ids, int64 (batch, maxlen_out) + if input_layer == "embed" + input tensor (batch, maxlen_out, #mels) in the other cases + ys_in_lens: (batch) + Returns: + (tuple): tuple containing: + + x: decoded token score before softmax (batch, maxlen_out, token) + if use_output_layer is True, + olens: (batch, ) + """ + x = tgt + if cache["decode_fsmn"] is None: + cache_layer_num = len(self.decoders) + if self.decoders2 is not None: + cache_layer_num += len(self.decoders2) + fsmn_cache = [None] * cache_layer_num + else: + fsmn_cache = cache["decode_fsmn"] + + if cache["opt"] is None: + cache_layer_num = len(self.decoders) + opt_cache = [None] * cache_layer_num + else: + opt_cache = cache["opt"] + + for i in range(self.att_layer_num): + decoder = self.decoders[i] + x, memory, fsmn_cache[i], opt_cache[i] = decoder.forward_chunk( + x, + memory, + fsmn_cache=fsmn_cache[i], + opt_cache=opt_cache[i], + chunk_size=cache["chunk_size"], + look_back=cache["decoder_chunk_look_back"], + ) + + if self.num_blocks - self.att_layer_num > 1: + for i in range(self.num_blocks - self.att_layer_num): + j = i + self.att_layer_num + decoder = self.decoders2[i] + x, memory, fsmn_cache[j], _ = decoder.forward_chunk( + x, memory, fsmn_cache=fsmn_cache[j] + ) + + for decoder in self.decoders3: + x, memory, _, _ = decoder.forward_chunk(x, memory) + if self.normalize_before: + x = self.after_norm(x) + if self.output_layer is not None: + x = self.output_layer(x) + + cache["decode_fsmn"] = fsmn_cache + if cache["decoder_chunk_look_back"] > 0 or cache["decoder_chunk_look_back"] == -1: + cache["opt"] = opt_cache + return x + + def forward_one_step( + self, + tgt: torch.Tensor, + tgt_mask: torch.Tensor, + memory: torch.Tensor, + cache: List[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """Forward one step. + + Args: + tgt: input token ids, int64 (batch, maxlen_out) + tgt_mask: input token mask, (batch, maxlen_out) + dtype=torch.uint8 in PyTorch 1.2- + dtype=torch.bool in PyTorch 1.2+ (include 1.2) + memory: encoded memory, float32 (batch, maxlen_in, feat) + cache: cached output list of (batch, max_time_out-1, size) + Returns: + y, cache: NN output value and cache per `self.decoders`. 
+ y.shape` is (batch, maxlen_out, token) + """ + x = self.embed(tgt) + if cache is None: + cache_layer_num = len(self.decoders) + if self.decoders2 is not None: + cache_layer_num += len(self.decoders2) + cache = [None] * cache_layer_num + new_cache = [] + # for c, decoder in zip(cache, self.decoders): + for i in range(self.att_layer_num): + decoder = self.decoders[i] + c = cache[i] + x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step( + x, tgt_mask, memory, None, cache=c + ) + new_cache.append(c_ret) + + if self.num_blocks - self.att_layer_num > 1: + for i in range(self.num_blocks - self.att_layer_num): + j = i + self.att_layer_num + decoder = self.decoders2[i] + c = cache[j] + x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step( + x, tgt_mask, memory, None, cache=c + ) + new_cache.append(c_ret) + + for decoder in self.decoders3: + + x, tgt_mask, memory, memory_mask, _ = decoder.forward_one_step( + x, tgt_mask, memory, None, cache=None + ) + + if self.normalize_before: + y = self.after_norm(x[:, -1]) + else: + y = x[:, -1] + if self.output_layer is not None: + y = torch.log_softmax(self.output_layer(y), dim=-1) + + return y, new_cache + diff --git a/funasr/models/paraformer_v2_community/model.py b/funasr/models/paraformer_v2_community/model.py new file mode 100644 index 000000000..276dec9f1 --- /dev/null +++ b/funasr/models/paraformer_v2_community/model.py @@ -0,0 +1,547 @@ +#!/usr/bin/env python3 +# -*- encoding: utf-8 -*- +# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. +# MIT License (https://opensource.org/licenses/MIT) + +import time +import copy +import torch +import logging +from torch.cuda.amp import autocast +from typing import Union, Dict, List, Tuple, Optional + +from funasr.register import tables +from funasr.models.ctc.ctc import CTC +from funasr.utils import postprocess_utils +from funasr.metrics.compute_acc import th_accuracy +from funasr.train_utils.device_funcs import to_device +from funasr.utils.datadir_writer import DatadirWriter +from funasr.models.paraformer.search import Hypothesis +from funasr.models.paraformer.cif_predictor import mae_loss +from funasr.train_utils.device_funcs import force_gatherable +from funasr.losses.label_smoothing_loss import LabelSmoothingLoss +from funasr.models.transformer.utils.add_sos_eos import add_sos_eos +from funasr.models.transformer.utils.nets_utils import make_pad_mask +from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard +from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank +from torch.nn.utils.rnn import pad_sequence +import torchaudio + +@tables.register("model_classes", "Paraformer_v2_community") +class Paraformer(torch.nn.Module): + """ + Author: Speech Lab of DAMO Academy, Alibaba Group + Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition + https://arxiv.org/abs/2206.08317 + """ + + def __init__( + self, + specaug: Optional[str] = None, + specaug_conf: Optional[Dict] = None, + normalize: str = None, + normalize_conf: Optional[Dict] = None, + encoder: str = None, + encoder_conf: Optional[Dict] = None, + decoder: str = None, + decoder_conf: Optional[Dict] = None, + ctc: str = None, + ctc_conf: Optional[Dict] = None, + ctc_weight: float = 0.5, + input_size: int = 80, + vocab_size: int = -1, + ignore_id: int = -1, + blank_id: int = 0, + sos: int = 1, + eos: int = 2, + lsm_weight: float = 0.0, + length_normalized_loss: bool = False, + # report_cer: bool = True, + # 
report_wer: bool = True, + # sym_space: str = "", + # sym_blank: str = "", + # extract_feats_in_collect_stats: bool = True, + # predictor=None, + share_embedding: bool = False, + # preencoder: Optional[AbsPreEncoder] = None, + # postencoder: Optional[AbsPostEncoder] = None, + use_1st_decoder_loss: bool = False, + **kwargs, + ): + + super().__init__() + + if specaug is not None: + specaug_class = tables.specaug_classes.get(specaug) + specaug = specaug_class(**specaug_conf) + if normalize is not None: + normalize_class = tables.normalize_classes.get(normalize) + normalize = normalize_class(**normalize_conf) + encoder_class = tables.encoder_classes.get(encoder) + encoder = encoder_class(input_size=input_size, **encoder_conf) + encoder_output_size = encoder.output_size() + + if decoder is not None: + decoder_class = tables.decoder_classes.get(decoder) + decoder = decoder_class( + vocab_size=vocab_size, + encoder_output_size=encoder_output_size, + **decoder_conf, + ) + if ctc_weight > 0.0: + + if ctc_conf is None: + ctc_conf = {} + + ctc = CTC(odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf) + + # note that eos is the same as sos (equivalent ID) + self.blank_id = blank_id + self.sos = sos if sos is not None else vocab_size - 1 + self.eos = eos if eos is not None else vocab_size - 1 + self.vocab_size = vocab_size + self.ignore_id = ignore_id + self.ctc_weight = ctc_weight + # self.token_list = token_list.copy() + # + # self.frontend = frontend + self.specaug = specaug + self.normalize = normalize + # self.preencoder = preencoder + # self.postencoder = postencoder + self.encoder = encoder + # + # if not hasattr(self.encoder, "interctc_use_conditioning"): + # self.encoder.interctc_use_conditioning = False + # if self.encoder.interctc_use_conditioning: + # self.encoder.conditioning_layer = torch.nn.Linear( + # vocab_size, self.encoder.output_size() + # ) + # + # self.error_calculator = None + # + if ctc_weight == 1.0: + self.decoder = None + else: + self.decoder = decoder + + self.criterion_att = LabelSmoothingLoss( + size=vocab_size, + padding_idx=ignore_id, + smoothing=lsm_weight, + normalize_length=length_normalized_loss, + ) + # + # if report_cer or report_wer: + # self.error_calculator = ErrorCalculator( + # token_list, sym_space, sym_blank, report_cer, report_wer + # ) + # + if ctc_weight == 0.0: + self.ctc = None + else: + self.ctc = ctc + # + # self.extract_feats_in_collect_stats = extract_feats_in_collect_stats + + self.share_embedding = share_embedding + if self.share_embedding: + self.decoder.embed = None + + self.use_1st_decoder_loss = use_1st_decoder_loss + self.length_normalized_loss = length_normalized_loss + self.beam_search = None + self.error_calculator = None + + def forward( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + **kwargs, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: + """Encoder + Decoder + Calc loss + Args: + speech: (Batch, Length, ...) 
+ speech_lengths: (Batch, ) + text: (Batch, Length) + text_lengths: (Batch,) + """ + if len(text_lengths.size()) > 1: + text_lengths = text_lengths[:, 0] + if len(speech_lengths.size()) > 1: + speech_lengths = speech_lengths[:, 0] + + batch_size = speech.shape[0] + + # Encoder + encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) + + loss_ctc, cer_ctc = None, None + loss_pre = None + stats = dict() + + # decoder: CTC branch + if self.ctc_weight != 0.0: + loss_ctc, cer_ctc = self._calc_ctc_loss( + encoder_out, encoder_out_lens, text, text_lengths + ) + + # Collect CTC branch stats + stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None + stats["cer_ctc"] = cer_ctc + + # decoder: Attention decoder branch + loss_att, acc_att, cer_att, wer_att = self._calc_att_loss( + encoder_out, encoder_out_lens, text, text_lengths + ) + + # 3. CTC-Att loss definition + if self.ctc_weight == 0.0: + loss = loss_att + else: + loss = ( + self.ctc_weight * loss_ctc + + (1 - self.ctc_weight) * loss_att + ) + + # Collect Attn branch stats + stats["loss_att"] = loss_att.detach() if loss_att is not None else None + stats["acc"] = acc_att + stats["cer"] = cer_att + stats["wer"] = wer_att + + stats["loss"] = torch.clone(loss.detach()) + stats["batch_size"] = batch_size + + # force_gatherable: to-device and to-tensor if scalar for DataParallel + if self.length_normalized_loss: + batch_size = (text_lengths).sum() + loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) + return loss, stats, weight + + def encode( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Encoder. Note that this method is used by asr_inference.py + Args: + speech: (Batch, Length, ...) + speech_lengths: (Batch, ) + ind: int + """ + with autocast(False): + + # Data augmentation + if self.specaug is not None and self.training: + speech, speech_lengths = self.specaug(speech, speech_lengths) + + # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN + if self.normalize is not None: + speech, speech_lengths = self.normalize(speech, speech_lengths) + + # Forward encoder + encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths) + if isinstance(encoder_out, tuple): + encoder_out = encoder_out[0] + + return encoder_out, encoder_out_lens + + def _calc_att_loss( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + ): + + # 0. sampler + decoder_out_1st = None + + batch_size = encoder_out.size(0) + ctc_probs_all = self.ctc.softmax(encoder_out) + compressed_ctc_list = [] + for b in range(batch_size): + ctc_prob_b = ctc_probs_all[b, :encoder_out_lens[b]] + text_b = ys_pad[b, :ys_pad_lens[b]] + with torch.no_grad(): + ctc_log_prob_b = ctc_prob_b.log() + align_path = self.force_align(ctc_log_prob_b.cpu(), text_b.cpu(), blank_id=self.blank_id) + align_path = align_path.to(encoder_out.device) + target_idx_path = self.map_alignment_to_target_index(align_path, self.blank_id) + ctc_comp = self.average_repeats_training(ctc_prob_b, target_idx_path, ys_pad_lens[b]) + compressed_ctc_list.append(ctc_comp) + + # 4. Pad Batch to [B, U_max, V] + padded_ctc_input = pad_sequence(compressed_ctc_list, batch_first=True).to(encoder_out.device) + + + + # 1. 
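# --- Illustrative note (not part of the patch) ---------------------------------
# Standalone toy example of the compression performed in the loop above: a
# frame-level CTC alignment path is mapped to target indices, and the frames of
# each target token are averaged, turning ctc_probs (T, V) into a (U, V) input
# for the decoder. Vocabulary assumed: 0=<blank>, 1=A, 2=B; text "A B" (U=2).
import torch

blank_id, V, U = 0, 3, 2
align_path = torch.tensor([1, 1, 0, 2, 2])               # frames: A A _ B B
is_token = align_path != blank_id
prev = torch.roll(align_path, 1)
prev[0] = blank_id                                        # force a segment start at frame 0
new_segment = is_token & (align_path != prev)
target_idx = torch.where(is_token,
                         torch.cumsum(new_segment.long(), dim=0) - 1,
                         torch.full_like(align_path, -1))
print(target_idx.tolist())                                # [0, 0, -1, 1, 1]

ctc_probs = torch.softmax(torch.randn(5, V), dim=-1)      # (T=5, V=3)
mask = target_idx != -1
idx = target_idx[mask].unsqueeze(1).repeat(1, V)          # (T_valid, V)
compressed = torch.zeros(U, V).scatter_add_(0, idx, ctc_probs[mask])
counts = torch.bincount(target_idx[mask], minlength=U).unsqueeze(1).float()
compressed = compressed / counts                          # (U, V): one averaged row per token
print(compressed.shape)                                   # torch.Size([2, 3])
# --------------------------------------------------------------------------------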
Forward decoder + decoder_outs = self.decoder(encoder_out, encoder_out_lens, padded_ctc_input, ys_pad_lens) + decoder_out, _ = decoder_outs[0], decoder_outs[1] + + if decoder_out_1st is None: + decoder_out_1st = decoder_out + # 2. Compute attention loss + loss_att = self.criterion_att(decoder_out, ys_pad) + acc_att = th_accuracy( + decoder_out_1st.view(-1, self.vocab_size), + ys_pad, + ignore_label=self.ignore_id, + ) + + # Compute cer/wer using attention-decoder + if self.training or self.error_calculator is None: + cer_att, wer_att = None, None + else: + ys_hat = decoder_out_1st.argmax(dim=-1) + cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu()) + + return loss_att, acc_att, cer_att, wer_att + + + def _calc_ctc_loss( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + ): + # Calc CTC loss + loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens) + + # Calc CER using CTC + cer_ctc = None + if not self.training and self.error_calculator is not None: + ys_hat = self.ctc.argmax(encoder_out).data + cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True) + return loss_ctc, cer_ctc + + def map_alignment_to_target_index(self, align_path, blank_id): + """ + Robustly map CTC alignment path (Token IDs) to Target Indices. + + Logic: + Detect boundaries where a new token segment begins. + A segment starts if the current frame is a Token AND it is different from the previous frame + (considering CTC topology where repeats are separated by blanks or are distinct tokens). + + Example: + Text: [A, B] + Align Path: [A, A, _, B, B] + Output: [0, 0, -1, 1, 1] + """ + # 1. Identify where the path is NOT blank + is_token = align_path != blank_id + + # 2. Identify transitions + prev_path = torch.roll(align_path, 1) + # Handle the very first frame: if it's a token, it must be the start of segment 0. + prev_path[0] = blank_id # force mismatch for the first element + + # A new segment starts if: It's a token AND (it differs from prev OR prev was blank) + # Note: If align_path[i] == align_path[i-1] (and not blank), it's the same segment. + new_segment_start = is_token & (align_path != prev_path) + + # 3. Cumulative sum to assign indices (1..U) + segment_ids = torch.cumsum(new_segment_start.long(), dim=0) - 1 + + # 4. Mask out blank positions with -1 + target_idx_path = torch.where(is_token, segment_ids, -1) + + return target_idx_path + def force_align(self, ctc_probs: torch.Tensor, y: torch.Tensor, blank_id=0) -> list: + """ctc forced alignment. + + Args: + torch.Tensor ctc_probs: hidden state sequence, 2d tensor (T, D) + torch.Tensor y: id sequence tensor 1d tensor (L) + int blank_id: blank symbol index + Returns: + torch.Tensor: alignment result + """ + ctc_probs = ctc_probs[None].cpu() + y = y[None].cpu() + alignments, _ = torchaudio.functional.forced_align(ctc_probs, y, blank=blank_id) + return alignments[0] + + def average_repeats_training(self, ctc_probs, target_idx_path, target_len): + """ + Aggregates frames belonging to the same target index using scatter_add. + + Args: + ctc_probs: [T, V] + target_idx_path: [T], values in [-1, 0, ... 
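# --- Illustrative note (not part of the patch) ---------------------------------
# Standalone sketch of the torchaudio forced-alignment call wrapped by
# force_align() above (requires torchaudio >= 2.1, which currently supports
# batch size 1 only): log-probs of shape (1, T, V) plus targets of shape (1, L)
# yield a frame-level label path of shape (1, T). Toy sizes are assumptions.
import torch
import torchaudio

log_probs = torch.log_softmax(torch.randn(1, 6, 3), dim=-1)        # (1, T=6, V=3)
targets = torch.tensor([[1, 2]], dtype=torch.int32)                # "A B", blank id 0
path, scores = torchaudio.functional.forced_align(log_probs, targets, blank=0)
print(path.shape)                                                   # torch.Size([1, 6])
# --------------------------------------------------------------------------------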
U-1] + target_len: U + Returns: + compressed: [U, V] + """ + U = target_len + V = ctc_probs.size(1) + + compressed = torch.zeros((U, V), device=ctc_probs.device, dtype=ctc_probs.dtype) + counts = torch.zeros((U, 1), device=ctc_probs.device, dtype=ctc_probs.dtype) + + # Filter valid frames (non-blank) + mask = target_idx_path != -1 + valid_indices = target_idx_path[mask] # [T_valid] + valid_probs = ctc_probs[mask] # [T_valid, V] + + if valid_indices.numel() == 0: + return compressed + + # Scatter Add Probs + index_expanded = valid_indices.unsqueeze(1).repeat(1, V) + compressed.scatter_add_(0, index_expanded, valid_probs) + + # Scatter Add Counts + ones = torch.ones((valid_indices.size(0), 1), device=ctc_probs.device) + counts.scatter_add_(0, valid_indices.unsqueeze(1), ones) + + # Average + compressed = compressed / (counts + 1e-9) + return compressed + + def average_repeats_inference(self, ctc_probs, greedy_path): + """ + Returns: + merged_probs: [U', V] + timestamps: List[Tuple[int, int]] -> [(start_frame, end_frame), ...] + """ + if greedy_path.numel() == 0: + return torch.zeros((0, ctc_probs.size(1)), device=ctc_probs.device), [] + + # Find consecutive segments in the greedy path + unique_tokens, counts = torch.unique_consecutive(greedy_path, return_counts=True) + + # Compute start and end indices for each segment + end_indices = torch.cumsum(counts, dim=0) + start_indices = torch.cat([torch.tensor([0], device=counts.device), end_indices[:-1]]) + + merged_probs = [] + + for i, token in enumerate(unique_tokens): + if token != self.blank_id: + start = start_indices[i].item() + end = end_indices[i].item() + + # Extract and average probabilities for the decoder + avg_prob = ctc_probs[start:end].mean(dim=0) + merged_probs.append(avg_prob) + + + if not merged_probs: + return torch.zeros((0, ctc_probs.size(1)), device=ctc_probs.device) + + return torch.stack(merged_probs) + + def inference( + self, + data_in, + data_lengths=None, + key: list = None, + tokenizer=None, + frontend=None, + **kwargs, + ): + + meta_data = {} + if ( + isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank" + ): # fbank + speech, speech_lengths = data_in, data_lengths + if len(speech.shape) < 3: + speech = speech[None, :, :] + if speech_lengths is not None: + speech_lengths = speech_lengths.squeeze(-1) + else: + speech_lengths = speech.shape[1] + else: + # extract fbank feats + time1 = time.perf_counter() + audio_sample_list = load_audio_text_image_video( + data_in, + fs=frontend.fs, + audio_fs=kwargs.get("fs", 16000), + data_type=kwargs.get("data_type", "sound"), + tokenizer=tokenizer, + ) + time2 = time.perf_counter() + meta_data["load_data"] = f"{time2 - time1:0.3f}" + speech, speech_lengths = extract_fbank( + audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend + ) + time3 = time.perf_counter() + meta_data["extract_feat"] = f"{time3 - time2:0.3f}" + meta_data["batch_data_time"] = ( + speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000 + ) + + speech = speech.to(device=kwargs["device"]) + speech_lengths = speech_lengths.to(device=kwargs["device"]) + # Encoder + if kwargs.get("fp16", False): + speech = speech.half() + encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) + if isinstance(encoder_out, tuple): + encoder_out = encoder_out[0] + + ctc_probs = self.ctc.softmax(encoder_out) + ctc_greedy_paths = ctc_probs.argmax(dim=-1) + + results = [] + batch_size, n, d = encoder_out.size() + if isinstance(key[0], (list, tuple)): + 
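# --- Illustrative note (not part of the patch) ---------------------------------
# Standalone toy example of the inference-time merging in
# average_repeats_inference() above: the greedy CTC path is split into
# consecutive runs with torch.unique_consecutive, blank runs are dropped, and
# the frame probabilities of each remaining run are averaged.
import torch

blank_id = 0
greedy_path = torch.tensor([0, 1, 1, 0, 0, 2, 2, 2])         # _ A A _ _ B B B
ctc_probs = torch.softmax(torch.randn(8, 3), dim=-1)          # (T=8, V=3)

tokens, counts = torch.unique_consecutive(greedy_path, return_counts=True)
ends = torch.cumsum(counts, dim=0)
starts = torch.cat([torch.tensor([0]), ends[:-1]])

merged = [ctc_probs[s:e].mean(dim=0)
          for tok, s, e in zip(tokens.tolist(), starts.tolist(), ends.tolist())
          if tok != blank_id]
compressed = torch.stack(merged)                              # (U'=2, V=3): rows for A, B
print(compressed.shape)
# --------------------------------------------------------------------------------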
key = key[0] + if len(key) < batch_size: + key = key * batch_size + for b in range(batch_size): + + probs = ctc_probs[b, :encoder_out_lens[b]] + path = ctc_greedy_paths[b, :encoder_out_lens[b]] + + # Get compressed probabilities and timestamp indices + compressed_prob = self.average_repeats_inference(probs, path) + + # Handling Noise/Silence (Empty Output) + if compressed_prob.size(0) == 0: + token_int = [] + timestamp_list = [] + else: + # 4. Decoder Forward + compressed_prob_in = compressed_prob.unsqueeze(0) # [1, U', V] + in_lens = torch.tensor([compressed_prob.size(0)], device=encoder_out.device) + + decoder_out, _ = self.decoder( + encoder_out[b:b+1], + encoder_out_lens[b:b+1], + compressed_prob_in, + in_lens,) + + + yseq = decoder_out.argmax(dim=-1)[0] + token_int = yseq.tolist() + + # remove blank symbol id, which is assumed to be 0 + token_int = list( + filter( + lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int + ) + ) + + result_i = {"key": key[b], "token_int": token_int} + results.append(result_i) + + return results, meta_data + + def export(self, **kwargs): + from .export_meta import export_rebuild_model + + if "max_seq_len" not in kwargs: + kwargs["max_seq_len"] = 512 + models = export_rebuild_model(model=self, **kwargs) + return models diff --git a/funasr/models/paraformer_v2_community/template.yaml b/funasr/models/paraformer_v2_community/template.yaml new file mode 100644 index 000000000..aa58f1288 --- /dev/null +++ b/funasr/models/paraformer_v2_community/template.yaml @@ -0,0 +1,112 @@ +# This is an example that demonstrates how to configure a model file. +# You can modify the configuration according to your own requirements. + +# to print the register_table: +# from funasr.register import tables +# tables.print() + +# network architecture +model: Paraformer_v2_community +model_conf: + ctc_weight: 0.5 + lsm_weight: 0.1 + length_normalized_loss: true + +# encoder +encoder: SANMEncoder +encoder_conf: + output_size: 512 + attention_heads: 4 + linear_units: 2048 + num_blocks: 50 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + attention_dropout_rate: 0.1 + input_layer: pe + pos_enc_class: SinusoidalPositionEncoder + normalize_before: true + kernel_size: 11 + sanm_shfit: 0 + selfattention_layer_type: sanm + +# decoder +decoder: ParaformerSANMDecoder_v2_community +decoder_conf: + attention_heads: 4 + linear_units: 2048 + num_blocks: 16 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + self_attention_dropout_rate: 0.1 + src_attention_dropout_rate: 0.1 + att_layer_num: 16 + kernel_size: 11 + sanm_shfit: 0 + input_layer: linear + +# frontend related +frontend: WavFrontend +frontend_conf: + fs: 16000 + window: hamming + n_mels: 80 + frame_length: 25 + frame_shift: 10 + lfr_m: 7 + lfr_n: 6 + +specaug: SpecAugLFR +specaug_conf: + apply_time_warp: false + time_warp_window: 5 + time_warp_mode: bicubic + apply_freq_mask: true + freq_mask_width_range: + - 0 + - 30 + lfr_rate: 6 + num_freq_mask: 1 + apply_time_mask: true + time_mask_width_range: + - 0 + - 12 + num_time_mask: 1 + +train_conf: + accum_grad: 1 + grad_clip: 5 + max_epoch: 150 + keep_nbest_models: 10 + avg_nbest_model: 10 + log_interval: 50 + +optim: adam +optim_conf: + lr: 0.0005 +scheduler: warmuplr +scheduler_conf: + warmup_steps: 30000 + +dataset: AudioDataset +dataset_conf: + index_ds: IndexDSJsonl + batch_sampler: BatchSampler + batch_type: example # example or length + batch_size: 1 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is 
source_token_len+target_token_len; + max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length, + buffer_size: 500 + shuffle: True + num_workers: 0 + +tokenizer: CharTokenizer +tokenizer_conf: + unk_symbol: + split_with_space: true + + +ctc_conf: + dropout_rate: 0.0 + ctc_type: builtin + reduce: true + ignore_nan_grad: true +normalize: null From 7f39f6298a225164af092ae5954feadeb70c737a Mon Sep 17 00:00:00 2001 From: dengcunqin Date: Sat, 17 Jan 2026 18:43:01 +0800 Subject: [PATCH 2/2] Fix return statement for empty greedy path case --- funasr/models/paraformer_v2_community/model.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/funasr/models/paraformer_v2_community/model.py b/funasr/models/paraformer_v2_community/model.py index 276dec9f1..2d3ac9070 100644 --- a/funasr/models/paraformer_v2_community/model.py +++ b/funasr/models/paraformer_v2_community/model.py @@ -413,7 +413,7 @@ def average_repeats_inference(self, ctc_probs, greedy_path): timestamps: List[Tuple[int, int]] -> [(start_frame, end_frame), ...] """ if greedy_path.numel() == 0: - return torch.zeros((0, ctc_probs.size(1)), device=ctc_probs.device), [] + return torch.zeros((0, ctc_probs.size(1)), device=ctc_probs.device) # Find consecutive segments in the greedy path unique_tokens, counts = torch.unique_consecutive(greedy_path, return_counts=True) @@ -510,7 +510,6 @@ def inference( # Handling Noise/Silence (Empty Output) if compressed_prob.size(0) == 0: token_int = [] - timestamp_list = [] else: # 4. Decoder Forward compressed_prob_in = compressed_prob.unsqueeze(0) # [1, U', V]
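For reference, a hedged usage sketch: once a model directory containing a config.yaml based on the template above, the trained checkpoint, and the tokenizer assets is available, the community model should be loadable through FunASR's standard AutoModel entry point. The directory and audio paths below are placeholders, not files shipped with this PR.

from funasr import AutoModel

# Placeholder path: a local directory with config.yaml (model: Paraformer_v2_community),
# the trained model.pt, and the token list produced by training.
model = AutoModel(model="./exp/paraformer_v2_community")
res = model.generate(input="example.wav")
print(res)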