包阅导读 Digest Summary
1. Keywords: RecurrentGemma, DecoderLayer, Linear, RMSNorm, Mlp
2. Summary: The text walks through the RecurrentGemmaForCausalLM model in the RecurrentGemma architecture: the embedding layer, the composition of the stacked decoder layers, and the final lm_head linear layer.
3. Main content:
– The RecurrentGemmaForCausalLM model:
– Embedding layer: Embedding(256000, 2560, padding_idx=0)
– Decoder layers:
– The stack mixes two decoder-layer variants: each RecurrentGemmaDecoderLayer carries either a RecurrentGemmaRecurrentBlock or a RecurrentGemmaSdpaAttention as its temporal block.
– Every decoder layer applies temporal and channel pre-normalization and combines linear projections with either a depthwise convolution plus recurrence or an attention mechanism.
– Final normalization layer: RecurrentGemmaRMSNorm()
– Linear layer: lm_head produces the output logits
– Modules inside the decoder layers (see the sketch after this list):
– The linear layers, depthwise Conv1d, and RG-LRU inside RecurrentGemmaRecurrentBlock.
– The projection layers and rotary embedding inside RecurrentGemmaSdpaAttention.
– The gate, up, and down linear layers plus the GELU activation inside RecurrentGemmaMlp.
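To make the dataflow concrete, here is a minimal PyTorch sketch of one decoder layer as the module names suggest: a pre-norm residual around the temporal block, then a pre-norm residual around the GeGLU-style MLP. The RG-LRU step is a simplification drawn from the Griffin paper (arXiv:2402.19427); the gating form and the constant c = 8 come from that paper, not from this article, and the real modules also take caches and position ids.

    import torch
    import torch.nn.functional as F

    def decoder_layer(x, layer):
        # Simplified: the real temporal blocks also take position ids and
        # caches. Pre-norm residual around the temporal block (recurrent
        # or attention), then around the channel MLP.
        x = x + layer.temporal_block(layer.temporal_pre_norm(x))
        x = x + geglu_mlp(layer.channel_pre_norm(x), layer.mlp_block)
        return x

    def geglu_mlp(x, mlp):
        # GeGLU-style MLP: the GELU(tanh)-activated gate path multiplies
        # the up projection elementwise, then down_proj maps back
        # (2560 -> 7680 -> 2560, matching the printout below).
        return mlp.down_proj(
            F.gelu(mlp.gate_proj(x), approximate="tanh") * mlp.up_proj(x)
        )

    def rg_lru_step(h_prev, x_t, r_t, i_t, lam, c=8.0):
        # One RG-LRU step (assumption: simplified from the Griffin paper;
        # r_t is the recurrence gate, i_t the input gate, both sigmoids).
        a = torch.sigmoid(lam)          # learned per-channel decay base
        a_t = a ** (c * r_t)            # gated effective decay in (0, 1)
        return a_t * h_prev + torch.sqrt(1.0 - a_t ** 2) * (i_t * x_t)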
Article URL: https://developers.googleblog.com/en/gemma-explained-recurrentgemma-architecture/
Source: developers.googleblog.com
Authors: Ju-yeong Ji, Ravin Kumar
Published: 2024-08-29
Language: English
Word count: 1,235
Estimated reading time: 5 minutes
Score: 90/100
Tags: RecurrentGemma, AI models, hybrid models, RNN, attention mechanism
The original article content follows:
RecurrentGemmaForCausalLM(
  (model): RecurrentGemmaModel(
    (embed_tokens): Embedding(256000, 2560, padding_idx=0)
    (layers): ModuleList(
      (0-1): 2 x RecurrentGemmaDecoderLayer(
        (temporal_pre_norm): RecurrentGemmaRMSNorm()
        (temporal_block): RecurrentGemmaRecurrentBlock(
          (linear_y): Linear(in_features=2560, out_features=2560, bias=True)
          (linear_x): Linear(in_features=2560, out_features=2560, bias=True)
          (linear_out): Linear(in_features=2560, out_features=2560, bias=True)
          (conv_1d): Conv1d(2560, 2560, kernel_size=(4,), stride=(1,), padding=(3,), groups=2560)
          (rg_lru): RecurrentGemmaRglru()
          (act_fn): PytorchGELUTanh()
        )
        (channel_pre_norm): RecurrentGemmaRMSNorm()
        (mlp_block): RecurrentGemmaMlp(
          (gate_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (up_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (down_proj): Linear(in_features=7680, out_features=2560, bias=True)
          (act_fn): PytorchGELUTanh()
        )
      )
      (2): RecurrentGemmaDecoderLayer(
        (temporal_pre_norm): RecurrentGemmaRMSNorm()
        (temporal_block): RecurrentGemmaSdpaAttention(
          (q_proj): Linear(in_features=2560, out_features=2560, bias=False)
          (k_proj): Linear(in_features=2560, out_features=256, bias=False)
          (v_proj): Linear(in_features=2560, out_features=256, bias=False)
          (o_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (rotary_emb): RecurrentGemmaRotaryEmbedding()
        )
        (channel_pre_norm): RecurrentGemmaRMSNorm()
        (mlp_block): RecurrentGemmaMlp(
          (gate_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (up_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (down_proj): Linear(in_features=7680, out_features=2560, bias=True)
          (act_fn): PytorchGELUTanh()
        )
      )
      :
      (23): RecurrentGemmaDecoderLayer(
        (temporal_pre_norm): RecurrentGemmaRMSNorm()
        (temporal_block): RecurrentGemmaSdpaAttention(
          (q_proj): Linear(in_features=2560, out_features=2560, bias=False)
          (k_proj): Linear(in_features=2560, out_features=256, bias=False)
          (v_proj): Linear(in_features=2560, out_features=256, bias=False)
          (o_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (rotary_emb): RecurrentGemmaRotaryEmbedding()
        )
        (channel_pre_norm): RecurrentGemmaRMSNorm()
        (mlp_block): RecurrentGemmaMlp(
          (gate_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (up_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (down_proj): Linear(in_features=7680, out_features=2560, bias=True)
          (act_fn): PytorchGELUTanh()
        )
      )
      (24-25): 2 x RecurrentGemmaDecoderLayer(
        (temporal_pre_norm): RecurrentGemmaRMSNorm()
        (temporal_block): RecurrentGemmaRecurrentBlock(
          (linear_y): Linear(in_features=2560, out_features=2560, bias=True)
          (linear_x): Linear(in_features=2560, out_features=2560, bias=True)
          (linear_out): Linear(in_features=2560, out_features=2560, bias=True)
          (conv_1d): Conv1d(2560, 2560, kernel_size=(4,), stride=(1,), padding=(3,), groups=2560)
          (rg_lru): RecurrentGemmaRglru()
          (act_fn): PytorchGELUTanh()
        )
        (channel_pre_norm): RecurrentGemmaRMSNorm()
        (mlp_block): RecurrentGemmaMlp(
          (gate_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (up_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (down_proj): Linear(in_features=7680, out_features=2560, bias=True)
          (act_fn): PytorchGELUTanh()
        )
      )
    )
    (final_norm): RecurrentGemmaRMSNorm()
  )
  (lm_head): Linear(in_features=2560, out_features=256000, bias=False)
)
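For reference, this module tree can be reproduced with the Hugging Face transformers library (a minimal sketch, assuming transformers v4.40+, which added RecurrentGemma support, and access to the google/recurrentgemma-2b checkpoint):

    import torch
    from transformers import AutoModelForCausalLM

    # Printing the loaded model yields the module tree above: 26 decoder
    # layers alternating two recurrent blocks with one local-attention
    # block, consistent with layers 0-1/24-25 (recurrent) and 2/23
    # (attention) in the printout.
    model = AutoModelForCausalLM.from_pretrained(
        "google/recurrentgemma-2b", torch_dtype=torch.bfloat16
    )
    print(model)

Note the attention shapes in the printout: q_proj maps to 2560 (ten 256-dimensional query heads) while k_proj and v_proj map to a single shared 256-dimensional head, i.e. multi-query attention.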