caffe 源码学习笔记(11) argmax layer
背景
似乎没什么背景,继续看caffe代码
argmax的作用是返回一个blob某个维度或者batch_size之后的维度的top_k的index(或者pair(index,value))
proto
还是先看proto
1
2message ArgMaxParameter {
3 // If true produce pairs (argmax, maxval)
4 optional bool out_max_val = 1 [default = false];
5 optional uint32 top_k = 2 [default = 1];
6 // The axis along which to maximise -- may be negative to index from the
7 // end (e.g., -1 for the last axis).
8 // By default ArgMaxLayer maximizes over the flattened trailing dimensions
9 // for each index of the first / num dimension.
10 optional int32 axis = 3;
11}
12
13
out_max_val为真表示输出(index,val)的pair,否则只输出index(?,存疑)
top_k应该是要取最大的top k个元素
axis是要求最大的维度,默认情况是把batch_size之后的维度flatten之后求argmax.
c++ 实现
先看Reshape的部分
1
2template <typename Dtype>
3void ArgMaxLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
4 const vector<Blob<Dtype>*>& top) {
5 int num_top_axes = bottom[0]->num_axes();
6 if ( num_top_axes < 3 ) num_top_axes = 3;
7 std::vector<int> shape(num_top_axes, 1);
8 if (has_axis_) {
9 // Produces max_ind or max_val per axis
10 shape = bottom[0]->shape();
11 shape[axis_] = top_k_;
12 // axis非默认参数的case: 只有求max的那个维度会变,其他都不变
13 // 问题: out_max_val似乎只适用在axis为默认参数的情况?
14 } else {
15 shape[0] = bottom[0]->shape(0);
16 // Produces max_ind
17 shape[2] = top_k_;
18 // 不是只拿到第top_k,而是拿到top_k的k个结果
19 if (out_max_val_) {
20 // Produces max_ind and max_val
21 shape[1] = 2;
22 }
23 // 默认axis参数得到的top blob的shape 为(batch_size,1或者2,top_k)
24 // 因为会把batch后面的维度flatten 然后求max
25 }
26 top[0]->Reshape(shape);
27}
28
添加了一些注释. 有一个疑问是,axis和out_max_val_这两个参数似乎不支持同时处理
继续看forward
1
2template <typename Dtype>
3void ArgMaxLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
4 const vector<Blob<Dtype>*>& top) {
5 const Dtype* bottom_data = bottom[0]->cpu_data();
6 Dtype* top_data = top[0]->mutable_cpu_data();
7 int dim, axis_dist;
8 if (has_axis_) {
9 dim = bottom[0]->shape(axis_);
10 // dim表示做argmax的维度一共有多少个值
11 // Distance between values of axis in blob
12 axis_dist = bottom[0]->count(axis_) / dim;
13 // 因为可能不在最末尾的维度做argmax,因此值在内存中未必是连续的
14 } else {
15 dim = bottom[0]->count(1);
16 // 从batch_size之后的维度数到最后
17 axis_dist = 1;
18 // 把末尾的几个维度flatten之后做argmax,在内存上这些值是连续的,因此axis_dist是1
19 }
20 int num = bottom[0]->count() / dim;
21 std::vector<std::pair<Dtype, int> > bottom_data_vector(dim);
22 for (int i = 0; i < num; ++i) {
23 for (int j = 0; j < dim; ++j) {
24 bottom_data_vector[j] = std::make_pair(
25 bottom_data[(i / axis_dist * dim + j) * axis_dist + i % axis_dist], j);
26 }
27 // 通过axis_dist控制,把要做argmax的元素从内存中不连续的位置传到一个连续的vector中,目的是做sort
28 std::partial_sort(
29 bottom_data_vector.begin(), bottom_data_vector.begin() + top_k_,
30 bottom_data_vector.end(), std::greater<std::pair<Dtype, int> >());
31 // 使得前top_k是最大的top_k个元素,后面的元素顺序任意
32 for (int j = 0; j < top_k_; ++j) {
33 if (out_max_val_) {
34 if (has_axis_) {
35 // Produces max_val per axis
36 top_data[(i / axis_dist * top_k_ + j) * axis_dist + i % axis_dist]
37 = bottom_data_vector[j].first;
38 //这个地方感觉有点问题... 就算是有axis参数不支持out_max_val... 输出的不也应该是index吗?
39 } else {
40 // Produces max_ind and max_val
41 top_data[2 * i * top_k_ + j] = bottom_data_vector[j].second;
42 top_data[2 * i * top_k_ + top_k_ + j] = bottom_data_vector[j].first;
43 }
44 } else {
45 // Produces max_ind per axis
46 top_data[(i / axis_dist * top_k_ + j) * axis_dist + i % axis_dist]
47 = bottom_data_vector[j].second;
48 }
49 }
50 }
51}
52
53
54
值得注意的是,由于可能不在末尾的维度求max,因此求max的值可能在内存上是不连续的, 注意看axis_dist这个变量.表示的就是要求argmax的相邻元素在内存中的距离.
然后接下来的代码有些让人困惑... 即使是不同时支持out_max_val和axis这两个参数...只有一个输出,那么输出的不也应该也是index吗?
这个输出好像不是很对啊????? 去官方的caffe确认了一下,也是这样写的.
不是很确定这是不是预期的行为.
update:
看来不止我发现了这个问题 some doubts about argmax_layer (may be bug) 可惜caffe看起来已经没人维护了2333
Posts in this Series
- caffe 源码阅读笔记
- [施工中]caffe 源码学习笔记(11) softmax
- caffe 源码学习笔记(11) argmax layer
- caffe 源码学习笔记(10) eltwise layer
- caffe 源码学习笔记(9) reduce layer
- caffe 源码学习笔记(8) loss function
- caffe 源码学习笔记(7) slice layer
- caffe 源码学习笔记(6) reshape layer
- caffe 源码学习笔记(5) 卷积
- caffe 源码学习笔记(4) 激活函数
- caffe 源码学习笔记(3) Net
- caffe 源码学习笔记(2) Layer
- caffe 源码学习笔记(1) Blob