微信截圖_17435904448874.png)
跟大牛學LLM訓練和使用技巧
scripts
utils
examples : 使用方法,參考案例
...
src/transformers (Transformer相關的代碼)
data: 數據處理
models : 模型的實現(xiàn)代碼,比如BERT, GPT,Whisper模型,都在此目錄下實現(xiàn)
generation : 文本生成相關代碼
...
從上面的代碼中,examples中提供了模型的使用方法的參考例子。
我們的今天介紹的主要內容都在 src/transformers 目錄下,其中 models 目錄下,是基于transformer的各種模型的實現(xiàn)代碼,Generation 包含通用的文本產生的實現(xiàn)代碼。
我們以Whisper 模型為例來詳細介紹一下代碼的結構和調用關系。下面我們以v4.29.1的版本為例進行介紹。
首先,whisper模型的代碼位于:src/transformers/models/whisper 目錄下。其主要功能都封裝在 modeling_whisper.py 文件中。
調用入口:WhisperForConditionalGeneration類
此python文件中包含多個類,繼承的關系比較復雜,它們之間的主要調用關系如下(以greedy search為例):
WhisperForConditionalGeneration (L1312) : 調用入口類
forward() (L1359) --> 細節(jié)在:WhisperModel,WhisperEncoder,WhisperDecoder 類中實現(xiàn)
generate() (L1455) --> 細節(jié)在: generation/utils.py#L1146 中實現(xiàn)
greedy_search(): L2164 --> 調用 search 函數來做實際的處理,比如自回歸處理
Forward函數:
forward函數位于 class transformers.WhisperModel 類中,代碼位置請參考:
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/models/whisper/modeling_whisper.py#L1215
def forward():
# Encoder將輸入的語音信號,編碼為聲學信息,也就是 encoder_outputs
encoder_outputs = self.encoder(
input_features,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# Decoder 的主要輸入為 decoder_input_ids (對應文本) 和 encoder_outputs (對應聲學信息,在翻譯任務中,對應著源語言)
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
encoder_hidden_states=encoder_outputs[0],
head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
通過上面,我們可以看到有 Encoder 部分和 Decoder 部分,分別對應聲學特征的提取和文本產生部分。
在 WhisperForConditionalGeneration 類中,也有一個forward函數,是對上面forward函數的封裝。
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/models/whisper/modeling_whisper.py#L1359
其中self.encoder的實現(xiàn)代碼位于 class WhisperEncoder(WhisperPreTrainedModel) 類中:
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/models/whisper/modeling_whisper.py#L735
其中self.decoder的實現(xiàn)代碼位于 class WhisperDecoder(WhisperPreTrainedModel) 類中:
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/models/whisper/modeling_whisper.py#L881
Generate函數
Generate函數入口:位于 class WhisperForConditionalGeneration(WhisperPreTrainedModel) 類中:
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/models/whisper/modeling_whisper.py#L1455
此處只是調用的入口,具體的實現(xiàn)代碼位于 class GenerationMixin 類中:
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/generation/utils.py#L1146
def generate() L1146
其中generate函數使用的greedy_search的實現(xiàn)位于:
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/generation/utils.py#L2164
下面,我們來進一步了解 generate 的實現(xiàn)代碼,來看看如何對此代碼進行修改。
入口代碼
Generate函數的入口位于: WhisperForConditionalGeneration類中的 def generate 函數
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/models/whisper/modeling_whisper.py#L1455
代碼的概要如下,從代碼中可以看到,這個函數主要是進行了一些參數設置,具體的實現(xiàn)是調用了父類中的對應函數來執(zhí)行的。
def generate()
# 參數設置部分
# 調用部分(此處調用了父類中的generate實現(xiàn))
return super().generate(
inputs,
generation_config,
logits_processor,
stopping_criteria,
prefix_allowed_tokens_fn,
synced_gpus,
**kwargs,
)
然后,我們可以逐級向上搜索其父類,可以看到
到此為止,我們就可以看到,具體的實現(xiàn)都在 GenerationMixin 類中。
Generate函數實現(xiàn)細節(jié)
下面,我們來看一下 GenerationMixin類中的 generate 函數的實現(xiàn)細節(jié)。
代碼位置:
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/generation/utils.py#L1146
其代碼概要如下:
def generate(): L1146
# 根據解碼方式的不同,此函數中有最多14步的處理步驟,我們以greedy search為例
# 1. Handle generation_config
and kwargs that might update it, and validate the .generate()
call
# 2. Set generation parameters if not already defined
# 3. Define model inputs
# 4. Define other model kwargs
# 5. Prepare input_ids
which will be used for auto-regressive generation
# 6. Prepare max_length
depending on other stopping criteria.
# 7. determine generation mode
# 8. prepare distribution pre_processing samplers
# 9. prepare stopping criteria
# 10. go into different generation modes
# 11. run greedy search (L1515)
def greedy_search(): L2164
# 初始化,設置
# 循環(huán)處理
while True: # L2317
# prepare model inputs (下面函數的具體實現(xiàn)位于: modeling_whisper.py#L1627)
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
# forward pass to get next token
# 這里是調用了 WhisperForConditionalGeneration 中的forward函數。這是因為 PyTorch 的 nn.Module 基類定義了一個 __call__ 方法,當你調用模型實例(即 self)時,它會自動調用這個 __call__ 方法,而這個 __call__ 方法又會調用 forward 方法。
outputs = self(
**model_inputs,
return_dict=True,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
# 得到下一個token的logits
next_token_logits = outputs.logits[:, -1, :]
# pre-process distribution 得到其score
next_tokens_scores = logits_processor(input_ids, next_token_logits)
# argmax :使用argmax 獲取對應的 tokens
next_tokens = torch.argmax(next_tokens_scores, dim=-1)
# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
# 判斷是否結束search:
# if eos_token was found in one sentence, set sentence to finished
# stop if we exceed the maximum length
通過上面的代碼概要,我們就可以知道generate函數進行了很多的設置以后,會調用 greedy_search() 函數來進行文本產生的實際處理。
到此為止,我們就已經對整個的代碼結構了解了。
下面我們通過幾個問題,來回顧一下對代碼的理解。
針對上面的問題7,如果要對generate或者其他部分進行修改,建議在 models/whisper的目錄下對父類函數進行重構。
比如,如果要對greedy_search功能進行調整來實現(xiàn)一些獨特的功能時,可以在modeling_whisper.py中重構 greedy_search(),具體做法可以是:
函數在子類中被重新實現(xiàn)之后,調用時,將會優(yōu)先調用新重構的函數。這樣既實現(xiàn)了自己獨特的功能,還不影響其他的模型的運行。
文章轉載自: Transformers Generate 功能介紹