scripts
utils
examples : 使用方法,參考案例
...
src/transformers (Transformer相關的代碼)
data: 數據處理
models : 模型的實現(xiàn)代碼,比如BERT, GPT,Whisper模型,都在此目錄下實現(xiàn)
generation : 文本生成相關代碼

...

從上面的代碼中,examples中提供了模型的使用方法的參考例子。
我們的今天介紹的主要內容都在 src/transformers 目錄下,其中 models 目錄下,是基于transformer的各種模型的實現(xiàn)代碼,Generation 包含通用的文本產生的實現(xiàn)代碼。

模型 models/whisper

我們以Whisper 模型為例來詳細介紹一下代碼的結構和調用關系。下面我們以v4.29.1的版本為例進行介紹。
首先,whisper模型的代碼位于:src/transformers/models/whisper 目錄下。其主要功能都封裝在 modeling_whisper.py 文件中。

調用入口:WhisperForConditionalGeneration類
此python文件中包含多個類,繼承的關系比較復雜,它們之間的主要調用關系如下(以greedy search為例):

WhisperForConditionalGeneration (L1312) : 調用入口類
forward() (L1359) --> 細節(jié)在:WhisperModel,WhisperEncoder,WhisperDecoder 類中實現(xiàn)
generate() (L1455) --> 細節(jié)在: generation/utils.py#L1146 中實現(xiàn)
greedy_search(): L2164 --> 調用 search 函數來做實際的處理,比如自回歸處理

Forward函數:
forward函數位于 class transformers.WhisperModel 類中,代碼位置請參考:
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/models/whisper/modeling_whisper.py#L1215

def forward():
# Encoder將輸入的語音信號,編碼為聲學信息,也就是 encoder_outputs
encoder_outputs = self.encoder(
input_features,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# Decoder 的主要輸入為 decoder_input_ids (對應文本) 和 encoder_outputs (對應聲學信息,在翻譯任務中,對應著源語言)
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
encoder_hidden_states=encoder_outputs[0],
head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

通過上面,我們可以看到有 Encoder 部分和 Decoder 部分,分別對應聲學特征的提取和文本產生部分。

在 WhisperForConditionalGeneration 類中,也有一個forward函數,是對上面forward函數的封裝。
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/models/whisper/modeling_whisper.py#L1359

其中self.encoder的實現(xiàn)代碼位于 class WhisperEncoder(WhisperPreTrainedModel) 類中:
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/models/whisper/modeling_whisper.py#L735

其中self.decoder的實現(xiàn)代碼位于 class WhisperDecoder(WhisperPreTrainedModel) 類中:
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/models/whisper/modeling_whisper.py#L881

Generate函數
Generate函數入口:位于 class WhisperForConditionalGeneration(WhisperPreTrainedModel) 類中:
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/models/whisper/modeling_whisper.py#L1455

此處只是調用的入口,具體的實現(xiàn)代碼位于 class GenerationMixin 類中:
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/generation/utils.py#L1146
def generate() L1146

其中generate函數使用的greedy_search的實現(xiàn)位于:
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/generation/utils.py#L2164

Generate 代碼分析

下面,我們來進一步了解 generate 的實現(xiàn)代碼,來看看如何對此代碼進行修改。

入口代碼
Generate函數的入口位于: WhisperForConditionalGeneration類中的 def generate 函數
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/models/whisper/modeling_whisper.py#L1455

代碼的概要如下,從代碼中可以看到,這個函數主要是進行了一些參數設置,具體的實現(xiàn)是調用了父類中的對應函數來執(zhí)行的。

def generate()
# 參數設置部分

# 調用部分(此處調用了父類中的generate實現(xiàn))
return super().generate(
inputs,
generation_config,
logits_processor,
stopping_criteria,
prefix_allowed_tokens_fn,
synced_gpus,
**kwargs,
)

然后,我們可以逐級向上搜索其父類,可以看到

到此為止,我們就可以看到,具體的實現(xiàn)都在 GenerationMixin 類中。

Generate函數實現(xiàn)細節(jié)

下面,我們來看一下 GenerationMixin類中的 generate 函數的實現(xiàn)細節(jié)。
代碼位置:
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/generation/utils.py#L1146

其代碼概要如下:

def generate(): L1146
# 根據解碼方式的不同,此函數中有最多14步的處理步驟,我們以greedy search為例
# 1. Handle generation_config and kwargs that might update it, and validate the .generate() call # 2. Set generation parameters if not already defined # 3. Define model inputs # 4. Define other model kwargs # 5. Prepare input_ids which will be used for auto-regressive generation # 6. Prepare max_length depending on other stopping criteria. # 7. determine generation mode # 8. prepare distribution pre_processing samplers # 9. prepare stopping criteria # 10. go into different generation modes # 11. run greedy search (L1515) def greedy_search(): L2164 # 初始化,設置 # 循環(huán)處理 while True: # L2317 # prepare model inputs (下面函數的具體實現(xiàn)位于: modeling_whisper.py#L1627) model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) # forward pass to get next token # 這里是調用了 WhisperForConditionalGeneration 中的forward函數。這是因為 PyTorch 的 nn.Module 基類定義了一個 __call__ 方法,當你調用模型實例(即 self)時,它會自動調用這個 __call__ 方法,而這個 __call__ 方法又會調用 forward 方法。 outputs = self( **model_inputs, return_dict=True, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) # 得到下一個token的logits next_token_logits = outputs.logits[:, -1, :] # pre-process distribution 得到其score next_tokens_scores = logits_processor(input_ids, next_token_logits) # argmax :使用argmax 獲取對應的 tokens next_tokens = torch.argmax(next_tokens_scores, dim=-1) # update generated ids, model inputs, and length for next step input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) # 判斷是否結束search: # if eos_token was found in one sentence, set sentence to finished # stop if we exceed the maximum length

通過上面的代碼概要,我們就可以知道generate函數進行了很多的設置以后,會調用 greedy_search() 函數來進行文本產生的實際處理。
到此為止,我們就已經對整個的代碼結構了解了。

下面我們通過幾個問題,來回顧一下對代碼的理解。

代碼修改建議

針對上面的問題7,如果要對generate或者其他部分進行修改,建議在 models/whisper的目錄下對父類函數進行重構。
比如,如果要對greedy_search功能進行調整來實現(xiàn)一些獨特的功能時,可以在modeling_whisper.py中重構 greedy_search(),具體做法可以是:

  1. 將 utils.py 中的 greedy_search 函數拷貝到 modeling_whisper.py 文件中。
  2. 需要import 一些必要的庫文件。(具體的庫,可以根據運行時的錯誤提示確定)
  3. 在greedy_search函數中進行修改,來實現(xiàn)想要的功能。

函數在子類中被重新實現(xiàn)之后,調用時,將會優(yōu)先調用新重構的函數。這樣既實現(xiàn)了自己獨特的功能,還不影響其他的模型的運行。

參考文獻

  1. 【基本概念】https://huggingface.co/blog/how-to-generate
  2. https://huggingface.co/docs/transformers/main_classes/text_generation
  3. https://huggingface.co/docs/transformers/internal/generation_utils

文章轉載自: Transformers Generate 功能介紹

上一篇:

如何使用python和django構建后端rest api

下一篇:

18種最佳 RAG 技術
#你可能也喜歡這些API文章!

我們有何不同?

API服務商零注冊

多API并行試用

數據驅動選型,提升決策效率

查看全部API→
??

熱門場景實測,選對API

#AI文本生成大模型API

對比大模型API的內容創(chuàng)意新穎性、情感共鳴力、商業(yè)轉化潛力

25個渠道
一鍵對比試用API 限時免費

#AI深度推理大模型API

對比大模型API的邏輯推理準確性、分析深度、可視化建議合理性

10個渠道
一鍵對比試用API 限時免費