caffe source code study notes (10) eltwise layer
Background
This layer is somewhat similar to the reduce layer, so I read them together. It takes at least two input blobs of the same shape, applies an element-wise operation across them, and produces a single output blob.
caffe supports three operations: "PROD", "SUM", and "MAX".
As an aside, TensorRT supports a few more:
enum class ElementWiseOperation : int
{
    kSUM = 0,  //!< Sum of the two elements.
    kPROD = 1, //!< Product of the two elements.
    kMAX = 2,  //!< Maximum of the two elements.
    kMIN = 3,  //!< Minimum of the two elements.
    kSUB = 4,  //!< Subtract the second element from the first.
    kDIV = 5,  //!< Divide the first element by the second.
    kPOW = 6   //!< The first element to the power of the second element.
};
proto
message EltwiseParameter {
  enum EltwiseOp {
    PROD = 0;
    SUM = 1;
    MAX = 2;
  }
  optional EltwiseOp operation = 1 [default = SUM]; // element-wise operation
  repeated float coeff = 2; // blob-wise coefficient for SUM operation

  // Whether to use an asymptotically slower (for >2 inputs) but stabler method
  // of computing the gradient for the PROD operation. (No effect for SUM op.)
  optional bool stable_prod_grad = 3 [default = true];
}
In the proto, coeff applies only to the SUM operation: it assigns each bottom blob a weighting coefficient. stable_prod_grad is only used in the backward pass, so we can ignore it here.
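For instance, a weighted sum of two bottom blobs could be configured like this (a minimal sketch; the layer and blob names "weighted_sum", "feat_a", "feat_b", and "fused" are made up for illustration):

```
layer {
  name: "weighted_sum"
  type: "Eltwise"
  bottom: "feat_a"
  bottom: "feat_b"
  top: "fused"
  eltwise_param {
    operation: SUM
    coeff: 0.7
    coeff: 0.3
  }
}
```

This would compute fused = 0.7 * feat_a + 0.3 * feat_b element-wise. When coeff is omitted, every coefficient defaults to 1 and the layer is a plain sum.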
C++ implementation
The code is fairly easy to follow; I have added some comments. Two things are worth pointing out. First, for PROD and MAX, the layer combines the first two blobs, then folds each remaining blob into the accumulated result. (This is the natural way to do it, so there is not much to say.)

Second, the mask variable is used by the MAX operation to record which bottom blob the maximum came from at each position; the backward pass needs this information.
template <typename Dtype>
void EltwiseLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
  for (int i = 1; i < bottom.size(); ++i) {
    CHECK(bottom[i]->shape() == bottom[0]->shape());
  }
  // Check that all bottom blobs have the same shape; there are at least two.
  top[0]->ReshapeLike(*bottom[0]);
  // If max operation, we will initialize the vector index part.
  if (this->layer_param_.eltwise_param().operation() ==
      EltwiseParameter_EltwiseOp_MAX && top.size() == 1) {
    max_idx_.Reshape(bottom[0]->shape());
  }
}

template <typename Dtype>
void EltwiseLayer<Dtype>::Forward_cpu(
    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
  int* mask = NULL;
  const Dtype* bottom_data_a = NULL;
  const Dtype* bottom_data_b = NULL;
  const int count = top[0]->count();
  Dtype* top_data = top[0]->mutable_cpu_data();
  switch (op_) {
  case EltwiseParameter_EltwiseOp_PROD:
    caffe_mul(count, bottom[0]->cpu_data(), bottom[1]->cpu_data(), top_data);
    for (int i = 2; i < bottom.size(); ++i) {
      caffe_mul(count, top_data, bottom[i]->cpu_data(), top_data);
    }
    // Multiply the first two blobs, then fold each remaining blob (if any)
    // into the accumulated result.
    break;
  case EltwiseParameter_EltwiseOp_SUM:
    caffe_set(count, Dtype(0), top_data);
    // Initialize top data to 0.
    // TODO(shelhamer) does BLAS optimize to sum for coeff = 1?
    for (int i = 0; i < bottom.size(); ++i) {
      caffe_axpy(count, coeffs_[i], bottom[i]->cpu_data(), top_data);
    }
    break;
  // What is mask for?
  // The forward pass does not need it; the backward pass needs to know at
  // which bottom blob each maximum was attained.
  case EltwiseParameter_EltwiseOp_MAX:
    // Initialize
    mask = max_idx_.mutable_cpu_data();
    caffe_set(count, -1, mask);
    caffe_set(count, Dtype(-FLT_MAX), top_data);
    // bottom 0 & 1
    bottom_data_a = bottom[0]->cpu_data();
    bottom_data_b = bottom[1]->cpu_data();
    for (int idx = 0; idx < count; ++idx) {
      if (bottom_data_a[idx] > bottom_data_b[idx]) {
        top_data[idx] = bottom_data_a[idx];  // maxval
        mask[idx] = 0;  // maxid
      } else {
        top_data[idx] = bottom_data_b[idx];  // maxval
        mask[idx] = 1;  // maxid
      }
    }
    // bottom 2++
    for (int blob_idx = 2; blob_idx < bottom.size(); ++blob_idx) {
      bottom_data_b = bottom[blob_idx]->cpu_data();
      for (int idx = 0; idx < count; ++idx) {
        if (bottom_data_b[idx] > top_data[idx]) {
          top_data[idx] = bottom_data_b[idx];  // maxval
          mask[idx] = blob_idx;  // maxid
        }
      }
    }
    break;
  default:
    LOG(FATAL) << "Unknown elementwise operation.";
  }
}