本例包含 Reddit 论坛评论数据集，使用 RNN 在论坛留言上训练语言模型（RNNLM，逐词预测下一个词），是 RNN 入门的简单易学教程。
代码片段和文件信息
属性 大小 日期 时间 名称
----------- --------- ---------- ----- ----
目录 0 2017-10-07 02:14 rnn-tutorial-rnnlm-master\
文件 29 2017-10-07 02:14 rnn-tutorial-rnnlm-master\.gitignore
文件 11358 2017-10-07 02:14 rnn-tutorial-rnnlm-master\LICENSE
文件 64 2017-10-07 02:14 rnn-tutorial-rnnlm-master\NOTICE
文件 2008 2017-10-07 02:14 rnn-tutorial-rnnlm-master\README.md
文件 43265 2017-10-07 02:14 rnn-tutorial-rnnlm-master\RNNLM.ipynb
目录 0 2017-10-07 02:14 rnn-tutorial-rnnlm-master\data\
文件 7610868 2017-10-07 02:14 rnn-tutorial-rnnlm-master\data\reddit-comments-2015-08.csv
文件 3210520 2017-10-07 02:14 rnn-tutorial-rnnlm-master\data\trained-model-theano.npz
文件 773 2017-10-07 02:14 rnn-tutorial-rnnlm-master\requirements.txt
文件 5391 2017-10-07 02:14 rnn-tutorial-rnnlm-master\rnn_theano.py
文件 3965 2017-10-07 02:14 rnn-tutorial-rnnlm-master\train-theano.py
文件 693 2017-10-07 02:14 rnn-tutorial-rnnlm-master\utils.py
import numpy as np
import theano as theano
import theano.tensor as T
from utils import *
import operator
class RNNTheano:
    """Recurrent neural network language model implemented with Theano.

    At every time step the network reads one word index, updates its
    hidden state and emits a softmax distribution over the vocabulary.

    Parameters (Theano shared variables):

    * ``U`` -- input-to-hidden weights, shape ``(hidden_dim, word_dim)``
    * ``W`` -- hidden-to-hidden weights, shape ``(hidden_dim, hidden_dim)``
    * ``V`` -- hidden-to-output weights, shape ``(word_dim, hidden_dim)``

    Args:
        word_dim: Vocabulary size; word indices must lie in
            ``[0, word_dim)``.
        hidden_dim: Size of the hidden state vector.
        bptt_truncate: Number of time steps to back-propagate through;
            forwarded to ``theano.scan`` as ``truncate_gradient``.
    """

    def __init__(self, word_dim, hidden_dim=100, bptt_truncate=4):
        # Assign instance variables
        self.word_dim = word_dim
        self.hidden_dim = hidden_dim
        self.bptt_truncate = bptt_truncate
        # Randomly initialize the network parameters uniformly in
        # [-1/sqrt(n), 1/sqrt(n)], where n is the layer's input dimension.
        U = np.random.uniform(-np.sqrt(1. / word_dim), np.sqrt(1. / word_dim),
                              (hidden_dim, word_dim))
        V = np.random.uniform(-np.sqrt(1. / hidden_dim), np.sqrt(1. / hidden_dim),
                              (word_dim, hidden_dim))
        W = np.random.uniform(-np.sqrt(1. / hidden_dim), np.sqrt(1. / hidden_dim),
                              (hidden_dim, hidden_dim))
        # Theano: create shared variables so the compiled sgd_step can
        # update them through its `updates` list.
        self.U = theano.shared(name='U', value=U.astype(theano.config.floatX))
        self.V = theano.shared(name='V', value=V.astype(theano.config.floatX))
        self.W = theano.shared(name='W', value=W.astype(theano.config.floatX))
        # We store the Theano graph here
        self.theano = {}
        self.__theano_build__()

    def __theano_build__(self):
        """Build the symbolic graph and compile the model's Theano functions.

        Sets ``forward_propagation``, ``predict``, ``ce_error``, ``bptt``
        and ``sgd_step`` as attributes on the instance.
        """
        U, V, W = self.U, self.V, self.W
        x = T.ivector('x')  # input word indices
        y = T.ivector('y')  # target word indices, one per input position

        def forward_prop_step(x_t, s_t_prev, U, V, W):
            # U[:, x_t] picks column x_t of U, which equals U times the
            # one-hot vector of word x_t without building that vector.
            s_t = T.tanh(U[:, x_t] + W.dot(s_t_prev))
            # Theano's softmax operates row-wise on a matrix; [0] takes the
            # single row produced for this vector input.
            o_t = T.nnet.softmax(V.dot(s_t))
            return [o_t[0], s_t]

        [o, s], updates = theano.scan(
            forward_prop_step,
            sequences=x,
            # Only the hidden state is fed back; o_t is not recurrent.
            outputs_info=[None, dict(initial=T.zeros(self.hidden_dim))],
            non_sequences=[U, V, W],
            truncate_gradient=self.bptt_truncate,
            strict=True)

        prediction = T.argmax(o, axis=1)
        o_error = T.sum(T.nnet.categorical_crossentropy(o, y))

        # Gradients
        dU = T.grad(o_error, U)
        dV = T.grad(o_error, V)
        dW = T.grad(o_error, W)

        # Assign functions
        self.forward_propagation = theano.function([x], o)
        self.predict = theano.function([x], prediction)
        self.ce_error = theano.function([x, y], o_error)
        self.bptt = theano.function([x, y], [dU, dV, dW])

        # SGD: one gradient-descent step on a single (x, y) pair.
        learning_rate = T.scalar('learning_rate')
        self.sgd_step = theano.function(
            [x, y, learning_rate], [],
            updates=[(self.U, self.U - learning_rate * dU),
                     (self.V, self.V - learning_rate * dV),
                     (self.W, self.W - learning_rate * dW)])

    def calculate_total_loss(self, X, Y):
        """Return the summed cross-entropy loss over all (x, y) pairs.

        Args:
            X: Iterable of input index sequences.
            Y: Iterable of target index sequences, aligned with ``X``.
        """
        return np.sum([self.ce_error(x, y) for x, y in zip(X, Y)])

    def calculate_loss(self, X, Y):
        """Return the total cross-entropy loss divided by the number of target words."""
        # Divide the total loss by the number of words
        num_words = np.sum([len(y) for y in Y])
        return self.calculate_total_loss(X, Y) / float(num_words)
def gradient_check_theano(model x y h=0.001 error_threshold=0.01):
# Overwrite the bptt attribute. We need to backpropagate all the
属性 大小 日期 时间 名称
----------- --------- ---------- ----- ----
目录 0 2017-10-07 02:14 rnn-tutorial-rnnlm-master\
文件 29 2017-10-07 02:14 rnn-tutorial-rnnlm-master\.gitignore
文件 11358 2017-10-07 02:14 rnn-tutorial-rnnlm-master\LICENSE
文件 64 2017-10-07 02:14 rnn-tutorial-rnnlm-master\NOTICE
文件 2008 2017-10-07 02:14 rnn-tutorial-rnnlm-master\README.md
文件 43265 2017-10-07 02:14 rnn-tutorial-rnnlm-master\RNNLM.ipynb
目录 0 2017-10-07 02:14 rnn-tutorial-rnnlm-master\data\
文件 7610868 2017-10-07 02:14 rnn-tutorial-rnnlm-master\data\reddit-comments-2015-08.csv
文件 3210520 2017-10-07 02:14 rnn-tutorial-rnnlm-master\data\trained-model-theano.npz
文件 773 2017-10-07 02:14 rnn-tutorial-rnnlm-master\requirements.txt
文件 5391 2017-10-07 02:14 rnn-tutorial-rnnlm-master\rnn_theano.py
文件 3965 2017-10-07 02:14 rnn-tutorial-rnnlm-master\train-theano.py
文件 693 2017-10-07 02:14 rnn-tutorial-rnnlm-master\utils.py
版权声明：本文内容由互联网用户自发贡献，该文观点仅代表作者本人。本站仅提供信息存储空间服务，不拥有所有权，不承担相关法律责任。如发现本站有涉嫌抄袭侵权/违法违规的内容，请发送邮件举报，一经查实，本站将立刻删除。
评论列表(条)