博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
Theano中如何只更新一部分权重,用法及理由。
阅读量:4068 次
发布时间:2019-05-25

本文共 1552 字,大约阅读时间需要 5 分钟。

如果你想只更新权重矩阵的一个子集(例如一些行或列)。这种情况下神经网络只利用这个子集来进行前向传播,那么成本函数(cost function),只应取决于迭代过程中使用的权重的子集。

举一个自然语言处理中的例子:

例如,如果你想学习一个查找表(lookup table),来做词嵌入,每一行是一个词向量。在每个迭代中,更新的参数应该只是在向前传播中使用到的那些行。对应于theano函数,为查找表定义一个共享变量:

lookup_table = theano.shared(matrix_ndarray)

通过一个整数向量索引获取表的一个子集。

subset = lookup_table[vector_of_indices]

从现在开始,只使用subset,不使用lookup_table[vector_of_indices]。后者在求grad时可能创建新变量。

定义成本函数:

cost = something that depends on subsetg = theano.grad(cost, subset)

为更新参数有两种方法:要么使用inc_subtensorset_subtensor。推荐使用inc_subtensor

updates = inc_subtensor(subset, g*lr)

OR
updates = set_subtensor(subset, subset + g*lr)

在这里:

f = theano.function(..., updates=[(lookup_table, updates)])

注意:++划重点!!!++ 上面只计算了subset的grad。也可以计算整个lookup_table的grad,区别就是计算得到的梯度gradient只在用来前向传播的那些行为非零值。如果用随机梯度下降算法SGD来更新参数的话,除了额外的计算外没有其他区别,只是其中有一些梯度为0(那些在前向传播中没有用到的行,即词向量)。但是,如果你想使用一个不同的优化方法如rmsprop或Hessian-Free优化,就会有问题。

  • In rmsprop, you keep an exponentially decaying squared gradient by whose square root you divide the current gradient to rescale the update step component-wise. If the gradient of the lookup table row which corresponds to a rare word is very often zero, the squared gradient history will tend to zero for that row because the history of that row decays towards zero.
  • Using Hessian-Free, you will get many zero rows and columns. Even one of them would make it non-invertible.

小结: In general, it would be better to compute the gradient only w.r.t. to those lookup table rows or columns which are actually used during the forward propagation.

另: 其他深度学习优化方法近期会专门写一篇博客介绍。

转载地址:http://jnoji.baihongyu.com/

你可能感兴趣的文章
【leetcode】Candy(python)
查看>>
【leetcode】Sum Root to leaf Numbers
查看>>
【leetcode】Pascal's Triangle II (python)
查看>>
java自定义容器排序的两种方法
查看>>
如何成为编程高手
查看>>
本科生的编程水平到底有多高
查看>>
AngularJS2中最基本的文件说明
查看>>
从头开始学习jsp(2)——jsp的基本语法
查看>>
使用与或运算完成两个整数的相加
查看>>
备忘:java中的递归
查看>>
Solr及Spring-Data-Solr入门学习
查看>>
python_time模块
查看>>
python_configparser(解析ini)
查看>>
selenium学习资料
查看>>
<转>文档视图指针互获
查看>>
从mysql中 导出/导入表及数据
查看>>
HQL语句大全(转)
查看>>
几个常用的Javascript字符串处理函数 spilt(),join(),substring()和indexof()
查看>>
javascript传参字符串 与引号的嵌套调用
查看>>
swiper插件的的使用
查看>>