torch.nn.GRU()函数解读
参考链接代码示例一个序列时:>>> import torch.nn as nn>>> gru = nn.GRU(input_size=50, hidden_size=50, batch_first=True)>>> embed = nn.Embedding(3, 50)>>> x = torch.LongTen...
·
参考链接
- 代码示例
一个序列时:
>>> import torch.nn as nn
>>> gru = nn.GRU(input_size=50, hidden_size=50, batch_first=True)
>>> embed = nn.Embedding(3, 50)
>>> x = torch.LongTensor([[0, 1, 2]])
>>> x_embed = embed(x)
>>> x.size()
torch.Size([1, 3])
>>> x_embed.size()
torch.Size([1, 3, 50])
>>> out, hidden = gru(x_embed)
>>> out.size()
torch.Size([1, 3, 50])
>>> hidden.size()
torch.Size([1, 1, 50])
两个示例时:
>>> x = torch.LongTensor([[0, 1, 2], [0, 1, 2]])
>>> x_embed = embed(x)
>>> x_embed.size()
torch.Size([2, 3, 50])
>>> out, hidden = gru(x_embed)
>>> out.size()
torch.Size([2, 3, 50])
>>> hidden.size()
torch.Size([1, 2, 50])
嵌入时:
>>> x = torch.LongTensor([[0, 1, 2], [0, 1, 2]])
>>> x_embed = embed(x)
>>> out1, hidden = gru(x_embed)
>>> out1.size()
torch.Size([2, 3, 50])
>>> hidden.size()
torch.Size([1, 2, 50])
>>> out2, hidden = gru(x_embed, hidden)
>>> out.size()
torch.Size([2, 3, 50])
>>> out2.size()
torch.Size([2, 3, 50])
>>> hidden.size()
torch.Size([1, 2, 50])
「智能机器人开发者大赛」官方平台,致力于为开发者和参赛选手提供赛事技术指导、行业标准解读及团队实战案例解析;聚焦智能机器人开发全栈技术闭环,助力开发者攻克技术瓶颈,促进软硬件集成、场景应用及商业化落地的深度研讨。 加入智能机器人开发者社区iRobot Developer,与全球极客并肩突破技术边界,定义机器人开发的未来范式!
更多推荐



所有评论(0)