Caffe source code study notes (6): reshape layer
Background
I've recently been hacking on TensorRT's caffe parser. When converting a caffe model to a TRT model, one of the required tweaks was padding the reshape layer's param with trailing 1s, which was tedious, so I took a look at how caffe's reshape layer is actually implemented.
proto
```protobuf
message ReshapeParameter {
  // Specify the output dimensions. If some of the dimensions are set to 0,
  // the corresponding dimension from the bottom layer is used (unchanged).
  // Exactly one dimension may be set to -1, in which case its value is
  // inferred from the count of the bottom blob and the remaining dimensions.
  // For example, suppose we want to reshape a 2D blob "input" with shape 2 x 8:
  //
  //   layer {
  //     type: "Reshape" bottom: "input" top: "output"
  //     reshape_param { ... }
  //   }
  //
  // If "input" is 2D with shape 2 x 8, then the following reshape_param
  // specifications are all equivalent, producing a 3D blob "output" with shape
  // 2 x 2 x 4:
  //
  //   reshape_param { shape { dim:  2  dim: 2  dim:  4 } }
  //   reshape_param { shape { dim:  0  dim: 2  dim:  4 } }
  //   reshape_param { shape { dim:  0  dim: 2  dim: -1 } }
  //   reshape_param { shape { dim:  0  dim:-1  dim:  4 } }
  //
  optional BlobShape shape = 1;

  // axis and num_axes control the portion of the bottom blob's shape that are
  // replaced by (included in) the reshape. By default (axis == 0 and
  // num_axes == -1), the entire bottom blob shape is included in the reshape,
  // and hence the shape field must specify the entire output shape.
  //
  // axis may be non-zero to retain some portion of the beginning of the input
  // shape (and may be negative to index from the end; e.g., -1 to begin the
  // reshape after the last axis, including nothing in the reshape,
  // -2 to include only the last axis, etc.).
  //
  // For example, suppose "input" is a 2D blob with shape 2 x 8.
  // Then the following ReshapeLayer specifications are all equivalent,
  // producing a blob "output" with shape 2 x 2 x 4:
  //
  //   reshape_param { shape { dim: 2  dim: 2  dim: 4 } }
  //   reshape_param { shape { dim: 2  dim: 4 } axis:  1 }
  //   reshape_param { shape { dim: 2  dim: 4 } axis: -3 }
  //
  // num_axes specifies the extent of the reshape.
  // If num_axes >= 0 (and axis >= 0), the reshape will be performed only on
  // input axes in the range [axis, axis+num_axes].
  // num_axes may also be -1, the default, to include all remaining axes
  // (starting from axis).
  //
  // For example, suppose "input" is a 2D blob with shape 2 x 8.
  // Then the following ReshapeLayer specifications are equivalent,
  // producing a blob "output" with shape 1 x 2 x 8.
  //
  //   reshape_param { shape { dim:  1  dim: 2  dim:  8 } }
  //   reshape_param { shape { dim:  1  dim: 2 } num_axes: 1 }
  //   reshape_param { shape { dim:  1 } num_axes: 0 }
  //
  // On the other hand, these would produce output blob shape 2 x 1 x 8:
  //
  //   reshape_param { shape { dim: 2  dim: 1  dim: 8 } }
  //   reshape_param { shape { dim: 1 } axis: 1 num_axes: 0 }
  //
  optional int32 axis = 2 [default = 0];
  optional int32 num_axes = 3 [default = -1];
}
```
Hmm, a bit complicated, isn't it? The complexity mostly comes from the two optional parameters, axis and num_axes. If we ignore those two, there are only two special values to keep in mind in the target dims: 0 means the dimension stays unchanged, and -1 means the dimension is to be inferred.
0 means “copy the respective dimension of the bottom layer”. That is, if the bottom has 2 as its 1st dimension, the top will have 2 as its 1st dimension as well, given dim: 0 as the 1st target dimension.
-1 stands for “infer this from the other dimensions”. This behavior is similar to that of -1 in numpy’s or [] for MATLAB’s reshape: this dimension is calculated to keep the overall element count the same as in the bottom layer. At most one -1 can be used in a reshape operation.
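To make these two rules concrete, here is a minimal prototxt sketch (the layer and blob names are made up for illustration). Suppose the bottom blob "data" has shape 2 x 3 x 4, i.e. count = 24:

```protobuf
layer {
  name: "reshape_demo"   # hypothetical name
  type: "Reshape"
  bottom: "data"         # assumed shape: 2 x 3 x 4 (count = 24)
  top: "reshaped"
  reshape_param {
    shape {
      dim: 0    # copy bottom axis 0 unchanged -> 2
      dim: -1   # inferred: 24 / (2 * 6) = 2
      dim: 6    # fixed
    }
  }
}
# resulting top shape: 2 x 2 x 6
```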
The axis and num_axes parameters are best read together:
they simply mean that the reshape is applied only to the input axes in the range [axis, axis + num_axes), while all other axes are kept as they are.
However, the axis usage example in the caffe.proto comment above is wrong, which made it confusing; I only figured it out after reading the code: for the 2 x 8 input, axis: -3 resolves to start_axis 0 and therefore reshapes the entire blob, so the example should read axis: -2. I opened a PR against caffe for this: fix an error of axis parameter in the example of ReshapeParameter #6936. Whether it gets merged is up to fate lol.
Two more cases are worth noting: num_axes left at its default (-1) means the reshape covers "all remaining axes" starting from axis; and axis may be negative, in which case num_axes is typically left unset (see the sketch below).
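As a sketch of the negative-axis case, reusing the 2 x 8 "input" blob from the proto comments: axis: -1 resolves to start_axis = 2 + (-1) + 1 = 2, so no bottom axes are replaced and the new dims are simply appended at the end:

```protobuf
# bottom "input" has shape 2 x 8; num_axes is left at its default (-1)
reshape_param { shape { dim: 1 } axis: -1 }
# no bottom axes are replaced, dim: 1 is appended -> top shape is 2 x 8 x 1
```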
It's worth mentioning that
specifying reshape_param { shape { dim: 0 dim: -1 } } makes the layer behave in exactly the same way as the Flatten layer.
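In prototxt form, that Flatten-equivalent reshape looks like this (names again made up for illustration):

```protobuf
layer {
  name: "flatten_via_reshape"   # hypothetical name
  type: "Reshape"
  bottom: "data"                # e.g. shape N x C x H x W
  top: "flat"                   # becomes N x (C*H*W)
  reshape_param { shape { dim: 0 dim: -1 } }
}
```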
C++ implementation
Let's first look at LayerSetUp. We rarely pay attention to this part of a layer, because for most layers there is nothing interesting in it.
```cpp
template <typename Dtype>
void ReshapeLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top) {
  CHECK_NE(top[0], bottom[0]) << this->type() << " Layer does not "
      "allow in-place computation.";
  inferred_axis_ = -1;
  copy_axes_.clear();
  const BlobShape& top_blob_shape = this->layer_param_.reshape_param().shape();
  const int top_num_axes = top_blob_shape.dim_size();
  constant_count_ = 1;
  for (int i = 0; i < top_num_axes; ++i) {
    const int top_dim = top_blob_shape.dim(i);
    if (top_dim == 0) {
      copy_axes_.push_back(i);
    } else if (top_dim == -1) {
      CHECK_EQ(inferred_axis_, -1) << "new shape contains multiple "
          << "-1 dims; at most a single (1) value of -1 may be specified";
      inferred_axis_ = i;
    } else {
      constant_count_ *= top_dim;
    }
  }
}
```
The loop special-cases dims of 0 and -1: the axes to be copied verbatim are collected in copy_axes_, the position of the -1 axis (if any) is remembered in inferred_axis_, and the product of the explicitly specified dims is accumulated into constant_count_ — my blind guess is that it's used for the inference later.
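To make that concrete: for reshape_param { shape { dim: 0 dim: 3 dim: -1 } }, LayerSetUp leaves us with copy_axes_ = {0}, inferred_axis_ = 2, and constant_count_ = 3.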
Next, let's look at Reshape.
```cpp
template <typename Dtype>
void ReshapeLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top) {
  const int input_start_axis = this->layer_param_.reshape_param().axis();
  const int start_axis = (input_start_axis >= 0) ? input_start_axis :
      bottom[0]->num_axes() + input_start_axis + 1;
  CHECK_GE(start_axis, 0) << "axis " << input_start_axis << " out of range";
  CHECK_LE(start_axis, bottom[0]->num_axes()) << "axis " << input_start_axis
      << " out of range for " << bottom[0]->num_axes() << "-D input blob";
  const int num_axes = this->layer_param_.reshape_param().num_axes();
  CHECK_GE(num_axes, -1) << "num_axes must be >= 0, or -1 for all";
  const int end_axis =
      (num_axes == -1) ? bottom[0]->num_axes() : (start_axis + num_axes);
  CHECK_LE(end_axis, bottom[0]->num_axes())
      << "end_axis = axis + num_axes is out of range";
  const int num_axes_replaced = end_axis - start_axis;
  const int num_axes_retained = bottom[0]->num_axes() - num_axes_replaced;
  const BlobShape& top_blob_shape = this->layer_param_.reshape_param().shape();
  const int num_new_axes = top_blob_shape.dim_size();
  vector<int> top_shape(num_axes_retained + num_new_axes);
  int top_shape_index = 0;
  for (int i = 0; i < start_axis; ++i) {
    top_shape[top_shape_index++] = bottom[0]->shape(i);
  }
  for (int i = 0; i < num_new_axes; ++i) {
    top_shape[top_shape_index++] = top_blob_shape.dim(i);
  }
  for (int i = end_axis; i < bottom[0]->num_axes(); ++i) {
    top_shape[top_shape_index++] = bottom[0]->shape(i);
  }
  CHECK_EQ(top_shape_index, top_shape.size());
  for (int i = 0; i < copy_axes_.size(); ++i) {
    const int copy_axis_index = copy_axes_[i];
    CHECK_GT(bottom[0]->num_axes(), start_axis + copy_axis_index)
        << "new shape contains a 0, but there was no corresponding bottom axis "
        << "to copy";
    top_shape[start_axis + copy_axis_index] =
        bottom[0]->shape(start_axis + copy_axis_index);
  }
  if (inferred_axis_ >= 0) {
    // A -1 dim was specified; infer the correct dimension by computing the
    // product of the other dimensions.
    int explicit_count = constant_count_;
    explicit_count *= bottom[0]->count(0, start_axis);
    explicit_count *= bottom[0]->count(end_axis);
    for (int i = 0; i < copy_axes_.size(); ++i) {
      const int copy_axis_index = copy_axes_[i];
      explicit_count *= top_shape[start_axis + copy_axis_index];
    }
    CHECK_EQ(0, bottom[0]->count() % explicit_count) << "bottom count ("
        << bottom[0]->count() << ") must be divisible by the product of "
        << "the specified dimensions (" << explicit_count << ")";
    const int inferred_dim = bottom[0]->count() / explicit_count;
    top_shape[start_axis + inferred_axis_] = inferred_dim;
  }
  top[0]->Reshape(top_shape);
  CHECK_EQ(top[0]->count(), bottom[0]->count())
      << "output count must match input count";
  top[0]->ShareData(*bottom[0]);
  top[0]->ShareDiff(*bottom[0]);
}

INSTANTIATE_CLASS(ReshapeLayer);
REGISTER_LAYER_CLASS(Reshape);

}  // namespace caffe
```
The code looks a bit long but is actually simple. The first half assembles top_shape quite directly — first the retained front axes, then the new axes from the param, then the retained back axes — plus a fair number of checks; the second half infers the -1 dimension. A worked trace follows.
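For example, take a bottom blob of shape 2 x 3 x 4 x 5 (count = 120) with reshape_param { shape { dim: -1 dim: 6 } axis: 1 num_axes: 2 }. Then start_axis = 1 and end_axis = 3; the three copy loops produce top_shape = [2, -1, 6, 5]. For the inference, explicit_count = constant_count_ (6) * count(0, 1) (= 2) * count(3) (= 5) = 60, so the -1 dim becomes 120 / 60 = 2 and the final top shape is 2 x 2 x 6 x 5.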
Also note that ReshapeLayer has no Forward function: it performs no computation and only changes the blob's shape. There is no data copy either — the top blob shares its data and diff with the bottom blob via ShareData/ShareDiff.