Caffe Source Code Study Notes (6): Reshape Layer

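Background

I have recently been hacking on TensorRT's Caffe parser. Previously, when converting a Caffe model to a TRT model, one of the required changes was appending a 1 to the end of each Reshape layer's param, which was tedious, so I took a look at how Caffe implements the Reshape layer.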

proto

message ReshapeParameter {
  // Specify the output dimensions. If some of the dimensions are set to 0,
  // the corresponding dimension from the bottom layer is used (unchanged).
  // Exactly one dimension may be set to -1, in which case its value is
  // inferred from the count of the bottom blob and the remaining dimensions.
  // For example, suppose we want to reshape a 2D blob "input" with shape 2 x 8:
  //
  //   layer {
  //     type: "Reshape" bottom: "input" top: "output"
  //     reshape_param { ... }
  //   }
  //
  // If "input" is 2D with shape 2 x 8, then the following reshape_param
  // specifications are all equivalent, producing a 3D blob "output" with shape
  // 2 x 2 x 4:
  //
  //   reshape_param { shape { dim:  2  dim: 2  dim:  4 } }
  //   reshape_param { shape { dim:  0  dim: 2  dim:  4 } }
  //   reshape_param { shape { dim:  0  dim: 2  dim: -1 } }
  //   reshape_param { shape { dim:  0  dim:-1  dim:  4 } }
  //
  optional BlobShape shape = 1;

  // axis and num_axes control the portion of the bottom blob's shape that are
  // replaced by (included in) the reshape. By default (axis == 0 and
  // num_axes == -1), the entire bottom blob shape is included in the reshape,
  // and hence the shape field must specify the entire output shape.
  //
  // axis may be non-zero to retain some portion of the beginning of the input
  // shape (and may be negative to index from the end; e.g., -1 to begin the
  // reshape after the last axis, including nothing in the reshape,
  // -2 to include only the last axis, etc.).
  //
  // For example, suppose "input" is a 2D blob with shape 2 x 8.
  // Then the following ReshapeLayer specifications are all equivalent,
  // producing a blob "output" with shape 2 x 2 x 4:
  //
  //   reshape_param { shape { dim: 2  dim: 2  dim: 4 } }
  //   reshape_param { shape { dim: 2  dim: 4 } axis:  1 }
  //   reshape_param { shape { dim: 2  dim: 4 } axis: -3 }
  //
  // num_axes specifies the extent of the reshape.
  // If num_axes >= 0 (and axis >= 0), the reshape will be performed only on
  // input axes in the range [axis, axis+num_axes].
  // num_axes may also be -1, the default, to include all remaining axes
  // (starting from axis).
  //
  // For example, suppose "input" is a 2D blob with shape 2 x 8.
  // Then the following ReshapeLayer specifications are equivalent,
  // producing a blob "output" with shape 1 x 2 x 8.
  //
  //   reshape_param { shape { dim:  1  dim: 2  dim:  8 } }
  //   reshape_param { shape { dim:  1  dim: 2  }  num_axes: 1 }
  //   reshape_param { shape { dim:  1  }  num_axes: 0 }
  //
  // On the other hand, these would produce output blob shape 2 x 1 x 8:
  //
  //   reshape_param { shape { dim: 2  dim: 1  dim: 8  }  }
  //   reshape_param { shape { dim: 1 }  axis: 1  num_axes: 0 }
  //
  optional int32 axis = 2 [default = 0];
  optional int32 num_axes = 3 [default = -1];
}

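Hmm, is this a bit more complicated than expected? Most of the complexity comes from the two optional parameters, axis and num_axes. Setting those two aside, only two things matter for the reshape dimensions: 0 means the dimension is left unchanged, and -1 means the dimension has to be inferred.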

0 means “copy the respective dimension of the bottom layer”. That is, if the bottom has 2 as its 1st dimension, the top will have 2 as its 1st dimension as well, given dim: 0 as the 1st target dimension.

-1 stands for “infer this from the other dimensions”. This behavior is similar to that of -1 in numpy’s or [] for MATLAB’s reshape: this dimension is calculated to keep the overall element count the same as in the bottom layer. At most one -1 can be used in a reshape operation.
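
To make the 0 and -1 rules concrete, here is a minimal sketch that resolves a shape spec against a bottom shape. It assumes the default axis == 0 and num_axes == -1; ResolveShape is a hypothetical helper written for this note, not a Caffe function, and assert() merely stands in for Caffe's CHECK_* macros.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper, not Caffe code. Assumes axis == 0 and num_axes == -1,
// so spec describes the entire output shape.
std::vector<int> ResolveShape(const std::vector<int>& bottom,
                              const std::vector<int>& spec) {
  int bottom_count = 1;
  for (int d : bottom) bottom_count *= d;

  std::vector<int> top(spec);
  int explicit_count = 1;  // product of every dim except the inferred one
  int inferred_axis = -1;  // position of the -1 entry, if any
  for (std::size_t i = 0; i < spec.size(); ++i) {
    if (spec[i] == 0) {
      // 0: copy the dimension at the same position from the bottom shape.
      assert(i < bottom.size() && "0 needs a matching bottom axis");
      top[i] = bottom[i];
      explicit_count *= top[i];
    } else if (spec[i] == -1) {
      // -1: remember where it is; at most one is allowed.
      assert(inferred_axis == -1 && "at most one -1 is allowed");
      inferred_axis = static_cast<int>(i);
    } else {
      explicit_count *= spec[i];
    }
  }
  if (inferred_axis >= 0) {
    // The inferred dim keeps the total element count unchanged.
    assert(explicit_count != 0 && bottom_count % explicit_count == 0);
    top[inferred_axis] = bottom_count / explicit_count;
  }
  return top;
}
```

With a 2 x 8 bottom, both ResolveShape({2, 8}, {0, 2, -1}) and ResolveShape({2, 8}, {0, -1, 4}) give 2 x 2 x 4, matching the proto comment. With a 2 x 3 x 4 bottom, the spec {0, -1} gives 2 x 12.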

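The axis and num_axes parameters are best read together.

They say that only the input axes in the half-open range [axis, axis+num_axes) are replaced by the reshape, while all other dimensions stay as they are. (The proto comment writes the range as [axis, axis+num_axes], but the code computes end_axis = start_axis + num_axes and retains every axis from end_axis onward.)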

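However, the axis example in the proto comment is wrong, which made it rather confusing; I only figured it out after reading the code. With the 2 x 8 input, axis: -3 resolves to start_axis = 2 + (-3) + 1 = 0, so the whole shape is replaced and the output is 2 x 4 rather than 2 x 2 x 4; axis: -2 is the value that gives start_axis = 1. I sent Caffe a PR, "fix an error of axis parameter in the example of ReshapeParameter #6936"; whether it gets merged is up to fate, lol.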

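Two more cases remain. One is the default num_axes of -1, which means "all remaining axes" are included. The other is a negative axis, which counts from the end and is converted to start_axis = (number of bottom axes) + axis + 1. The proto comment only describes num_axes >= 0 together with axis >= 0, but in the code num_axes is applied to the resolved start_axis in the same way in both cases.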
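
The axis and num_axes rules can be sketched as a small splice over the bottom shape. SpliceShape is a hypothetical helper, not part of Caffe; to keep the focus on the axis range it takes the spec entries literally (no 0 or -1 handling) and does not check that the element count is preserved. assert() stands in for Caffe's CHECK_* macros.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper, not Caffe code. Replaces the half-open axis range
// [start_axis, end_axis) of bottom with the dims listed in spec.
std::vector<int> SpliceShape(const std::vector<int>& bottom,
                             const std::vector<int>& spec,
                             int axis = 0, int num_axes = -1) {
  const int bottom_axes = static_cast<int>(bottom.size());
  // A negative axis counts from the end; -1 maps to bottom_axes,
  // i.e. "start after the last axis".
  const int start_axis = (axis >= 0) ? axis : bottom_axes + axis + 1;
  assert(start_axis >= 0 && start_axis <= bottom_axes);
  assert(num_axes >= -1);
  // num_axes == -1 means "all remaining axes starting from start_axis".
  const int end_axis = (num_axes == -1) ? bottom_axes : start_axis + num_axes;
  assert(end_axis <= bottom_axes);

  // Retained leading axes, then the new dims, then retained trailing axes.
  std::vector<int> top(bottom.begin(), bottom.begin() + start_axis);
  top.insert(top.end(), spec.begin(), spec.end());
  top.insert(top.end(), bottom.begin() + end_axis, bottom.end());
  return top;
}
```

For a 2 x 8 bottom, SpliceShape({2, 8}, {2, 4}, 1) and SpliceShape({2, 8}, {2, 4}, -2) both give 2 x 2 x 4, SpliceShape({2, 8}, {1}, 0, 0) gives 1 x 2 x 8, and SpliceShape({2, 8}, {1}, 1, 0) gives 2 x 1 x 8.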

It is worth mentioning that

specifying reshape_param { shape { dim: 0 dim: -1 } } makes the layer behave in exactly the same way as the Flatten layer.

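C++ implementation

Start with LayerSetUp. We rarely seem to pay attention to this part of a layer, because for most layers there is not much in it worth looking at.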

template <typename Dtype>
void ReshapeLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
  CHECK_NE(top[0], bottom[0]) << this->type() << " Layer does not "
      "allow in-place computation.";
  inferred_axis_ = -1;
  copy_axes_.clear();
  const BlobShape& top_blob_shape = this->layer_param_.reshape_param().shape();
  const int top_num_axes = top_blob_shape.dim_size();
  constant_count_ = 1;
  for (int i = 0; i < top_num_axes; ++i) {
    const int top_dim = top_blob_shape.dim(i);
    if (top_dim == 0) {
      copy_axes_.push_back(i);
    } else if (top_dim == -1) {
      CHECK_EQ(inferred_axis_, -1) << "new shape contains multiple "
          << "-1 dims; at most a single (1) value of -1 may be specified";
      inferred_axis_ = i;
    } else {
      constant_count_ *= top_dim;
    }
  }
}

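The dims 0 and -1 get special handling, and the product of the explicitly specified dims is accumulated in constant_count_; my guess is that it is used for the inference later.

Next, let's look at Reshape.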

template <typename Dtype>
void ReshapeLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
  const int input_start_axis = this->layer_param_.reshape_param().axis();
  const int start_axis = (input_start_axis >= 0) ? input_start_axis :
      bottom[0]->num_axes() + input_start_axis + 1;
  CHECK_GE(start_axis, 0) << "axis " << input_start_axis << " out of range";
  CHECK_LE(start_axis, bottom[0]->num_axes()) << "axis " << input_start_axis
      << " out of range for " << bottom[0]->num_axes() << "-D input blob";
  const int num_axes = this->layer_param_.reshape_param().num_axes();
  CHECK_GE(num_axes, -1) << "num_axes must be >= 0, or -1 for all";
  const int end_axis =
      (num_axes == -1) ? bottom[0]->num_axes() : (start_axis + num_axes);
  CHECK_LE(end_axis, bottom[0]->num_axes())
      << "end_axis = axis + num_axes is out of range";
  const int num_axes_replaced = end_axis - start_axis;
  const int num_axes_retained = bottom[0]->num_axes() - num_axes_replaced;
  const BlobShape& top_blob_shape = this->layer_param_.reshape_param().shape();
  const int num_new_axes = top_blob_shape.dim_size();
  vector<int> top_shape(num_axes_retained + num_new_axes);
  int top_shape_index = 0;
  for (int i = 0; i < start_axis; ++i) {
    top_shape[top_shape_index++] = bottom[0]->shape(i);
  }
  for (int i = 0; i < num_new_axes; ++i) {
    top_shape[top_shape_index++] = top_blob_shape.dim(i);
  }
  for (int i = end_axis; i < bottom[0]->num_axes(); ++i) {
    top_shape[top_shape_index++] = bottom[0]->shape(i);
  }
  CHECK_EQ(top_shape_index, top_shape.size());
  for (int i = 0; i < copy_axes_.size(); ++i) {
    const int copy_axis_index = copy_axes_[i];
    CHECK_GT(bottom[0]->num_axes(), start_axis + copy_axis_index)
        << "new shape contains a 0, but there was no corresponding bottom axis "
        << "to copy";
    top_shape[start_axis + copy_axis_index] =
        bottom[0]->shape(start_axis + copy_axis_index);
  }
  if (inferred_axis_ >= 0) {
    // A -1 dim was specified; infer the correct dimension by computing the
    // product of the other dimensions.
    int explicit_count = constant_count_;
    explicit_count *= bottom[0]->count(0, start_axis);
    explicit_count *= bottom[0]->count(end_axis);
    for (int i = 0; i < copy_axes_.size(); ++i) {
      const int copy_axis_index = copy_axes_[i];
      explicit_count *= top_shape[start_axis + copy_axis_index];
    }
    CHECK_EQ(0, bottom[0]->count() % explicit_count) << "bottom count ("
        << bottom[0]->count() << ") must be divisible by the product of "
        << "the specified dimensions (" << explicit_count << ")";
    const int inferred_dim = bottom[0]->count() / explicit_count;
    top_shape[start_axis + inferred_axis_] = inferred_dim;
  }
  top[0]->Reshape(top_shape);
  CHECK_EQ(top[0]->count(), bottom[0]->count())
      << "output count must match input count";
  top[0]->ShareData(*bottom[0]);
  top[0]->ShareDiff(*bottom[0]);
}

INSTANTIATE_CLASS(ReshapeLayer);
REGISTER_LAYER_CLASS(Reshape);

}  // namespace caffe

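The code looks a bit long but is actually simple. The first half is straightforward: it is mostly checks plus assembling top_shape. The second half infers the -1 dimension.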
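
As a worked example of the inference step when only part of the shape is replaced, the sketch below recomputes explicit_count the same way: retained leading axes, retained trailing axes, constant dims, and copied dims all count as "explicit". Count and InferDim are hypothetical helpers for illustration, not Caffe functions, and assert() stands in for CHECK_*.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

// Product of shape[begin, end); an empty range yields 1.
int Count(const std::vector<int>& shape, int begin, int end) {
  return std::accumulate(shape.begin() + begin, shape.begin() + end, 1,
                         std::multiplies<int>());
}

// Hypothetical helper, not Caffe code. Returns the value that the single -1
// in spec resolves to when axes [start_axis, end_axis) of bottom are replaced.
int InferDim(const std::vector<int>& bottom, const std::vector<int>& spec,
             int start_axis, int end_axis) {
  const int bottom_axes = static_cast<int>(bottom.size());
  // Axes outside the replaced range are kept, so they are already "explicit".
  int explicit_count = Count(bottom, 0, start_axis) *
                       Count(bottom, end_axis, bottom_axes);
  for (std::size_t i = 0; i < spec.size(); ++i) {
    if (spec[i] == 0) {
      explicit_count *= bottom[start_axis + i];  // copied dim
    } else if (spec[i] != -1) {
      explicit_count *= spec[i];                 // constant dim
    }
  }
  const int total = Count(bottom, 0, bottom_axes);
  assert(explicit_count != 0 && total % explicit_count == 0);
  return total / explicit_count;
}
```

For a 2 x 3 x 4 x 5 bottom with shape { dim: -1 dim: 2 }, axis: 1 and num_axes: 2, the explicit count is 2 (leading) x 5 (trailing) x 2 (constant) = 20, so InferDim({2, 3, 4, 5}, {-1, 2}, 1, 3) is 120 / 20 = 6 and the top shape is 2 x 6 x 2 x 5.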

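The Reshape layer has no Forward function here, because nothing is computed: it only changes the blob's shape, and no data is copied either, since top shares its data and diff with bottom through ShareData and ShareDiff.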
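
The "new shape, same storage" idea can be illustrated with a toy type. TinyBlob and ReshapeView below are hypothetical and far simpler than Caffe's Blob; the sketch only assumes that sharing means both blobs point at one buffer.

```cpp
#include <cassert>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical stand-in for a blob: a shape plus shared storage.
// This is not Caffe's Blob class.
struct TinyBlob {
  std::vector<int> shape;
  std::shared_ptr<std::vector<float>> data;

  // Point this blob at the same storage as other; no element is copied.
  void ShareData(const TinyBlob& other) { data = other.data; }
};

// Builds a top blob with a new shape that aliases bottom's storage.
TinyBlob ReshapeView(const TinyBlob& bottom, std::vector<int> new_shape) {
  int old_count = 1;
  int new_count = 1;
  for (int d : bottom.shape) old_count *= d;
  for (int d : new_shape) new_count *= d;
  assert(old_count == new_count && "output count must match input count");

  TinyBlob top;
  top.shape = std::move(new_shape);
  top.ShareData(bottom);
  return top;
}
```

After TinyBlob top = ReshapeView(bottom, {2, 2, 4}) on a 2 x 8 bottom, top.data and bottom.data point at the same buffer, so a write through one is visible through the other; only the shape metadata differs.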
