TensorFlow Architecture Study Notes (2): Adding a New Op
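1. Register the op in a C++ file. The registration covers the number and types of the arguments, the return types, the op name, and so on. It is roughly the equivalent of a function declaration in a C++ header. The shape function is also defined here.
2. Implement the op in C++. The implementation is called a kernel and corresponds to the registration from step 1. An op with the same functionality may have different implementations for different input/output types or architectures (GPU, CPU).
3. Create a Python wrapper (**optional**). The wrapper makes the op callable through a Python API. A default wrapper is generated from the op registration.
4. Write a function that computes the op's gradient (**optional**).
5. Test the newly added op. For convenience this is usually done in Python, but testing in C++ works just as well.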
Define the op's interface
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
using namespace tensorflow;

REGISTER_OP("ZeroOut")
    .Input("to_zero: int32")
    .Output("zeroed: int32")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      // The output has the same shape as the input.
      c->set_output(0, c->input(0));
      return Status::OK();
    });
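// REGISTER_OP("my_op_name")
//     .Attr("<name>:<type>")
//     .Attr("<name>:<type>=<default>")
//     .Input("<name>:<type-expr>")
//     .Input("<name>:Ref(<type-expr>)")
//     .Output("<name>:<type-expr>")
//     .Doc(R"(
// <1-line summary>
// <rest of the description>
// <name>: <description of name>
// <name>: <description of name>
// )");
// Note: .Doc() should be last.
// For details, see the OpDefBuilder class in op_def_builder.h.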
// op_def_builder.h (excerpt): the builder machinery behind op registration.
#include <string>
#include <vector>
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
namespace tensorflow {
// Builder class passed to the REGISTER_OP() macro.
class OpDefBuilder {
public:
// Constructs an OpDef with just the name field set.
explicit OpDefBuilder(StringPiece op_name);
// Adds an attr to this OpDefBuilder (and returns *this). The spec has
// format "<name>:<type>" or "<name>:<type>=<default>"
// where <name> matches regexp [a-zA-Z][a-zA-Z0-9_]*
// (by convention only using capital letters for attrs that can be inferred)
// <type> can be:
//   "string", "int", "float", "bool", "type", "shape", or "tensor"
//   "numbertype", "realnumbertype", "quantizedtype", "{int32,int64}"
//       (meaning "type" with a restriction on valid values)
//   "{\"foo\", \"bar\n baz\"}", or "{'foo', 'bar\n baz'}"
//       (meaning "string" with a restriction on valid values)
//   "list(string)", ..., "list(tensor)", "list(numbertype)", ...
//       (meaning lists of the above types)
//   "int >= 2" (meaning "int" with a restriction on valid values)
//   "list(string) >= 2", "list(int) >= 2"
//       (meaning "list(string)" / "list(int)" with length at least 2)
// <default>, if included, should use the Proto text format
// of <type>. For lists use [a, b, c] format.
// Note that any attr specifying the length of an input or output will
// get a default minimum of 1 unless the >= # syntax is used.
// TODO(josh11b): Perhaps support restrictions and defaults as optional
// extra arguments to Attr() instead of encoding them in the spec string.
// TODO(josh11b): Would like to have better dtype handling for tensor attrs:
// * Ability to say the type of an input/output matches the type of
// the tensor.
// * Ability to restrict the type of the tensor like the existing
// restrictions for type attrs.
// Perhaps by linking the type of the tensor to a type attr?
OpDefBuilder& Attr(StringPiece spec);
// Adds an input or output to this OpDefBuilder (and returns *this).
// The spec has form "<name>:<type-expr>" or "<name>:Ref(<type-expr>)"
// where <name> matches regexp [a-z][a-z0-9_]* and <type-expr> can be:
// * For a single tensor: <type>
// * For a sequence of tensors with the same type: <number>*<type>
// * For a sequence of tensors with different types: <type-list>
// Where:
//   <type> is either one of "float", "int32", "string", ...
//          or the name of an attr (see above) with type "type".
//   <number> is the name of an attr with type "int".
//   <type-list> is the name of an attr with type "list(type)".
// TODO(josh11b): Indicate Ref() via an optional argument instead of
// in the spec?
// TODO(josh11b): SparseInput() and SparseOutput() matching the Python
// handling?
OpDefBuilder& Input(StringPiece spec);
OpDefBuilder& Output(StringPiece spec);
// Turns on the indicated boolean flag in this OpDefBuilder (and
// returns *this).
OpDefBuilder& SetIsCommutative();
OpDefBuilder& SetIsAggregate();
OpDefBuilder& SetIsStateful();
OpDefBuilder& SetAllowsUninitializedInput();
// Deprecate the op at a certain GraphDef version.
OpDefBuilder& Deprecated(int version, StringPiece explanation);
// Adds docs to this OpDefBuilder (and returns *this).
// Docs have the format:
//   <1-line summary>
//   <rest of the description>
//   <name>: <description of name>
//   <name>: <description of name>
//     <if long, indent the description on subsequent lines>
// Where <name> is the name of an attr, input, or output. Please
// wrap docs at 72 columns so that it may be indented in the
// generated output. For tensor inputs or outputs (not attrs), you
// may start the description with an "=" (like name:= <description>)
// to suppress the automatically-generated type documentation in
// generated output.
OpDefBuilder& Doc(StringPiece text);
// Sets *op_def to the requested OpDef, or returns an error.
// Must be called after all of the above methods.
// Note that OpDefBuilder only reports parsing errors. You should also
// call ValidateOpDef() to detect other problems.
Status Finalize(OpDef* op_def) const;
private:
OpDef op_def_;
std::vector<string> attrs_;
std::vector<string> inputs_;
std::vector<string> outputs_;
string doc_;
std::vector<string> errors_;
};
} // namespace tensorflow
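To make the attr spec format from the OpDefBuilder comments more concrete, here is a minimal standalone sketch that splits a spec of the form "<name>:<type>" or "<name>:<type>=<default>" and checks the name against the regexp [a-zA-Z][a-zA-Z0-9_]*. `AttrSpecSketch` and `ParseAttrSpecSketch` are hypothetical names invented for this note. They are not part of TensorFlow, and the real parser in op_def_builder.cc handles many more cases, such as '=' characters inside quoted string restrictions.

```cpp
#include <cstddef>
#include <regex>
#include <string>

// Hypothetical helper type (not a TensorFlow type): the pieces of an attr spec.
struct AttrSpecSketch {
  std::string name;
  std::string type;           // may include a restriction such as "int >= 2"
  bool has_default = false;   // true only if "=<default>" is present
  std::string default_value;
};

// Removes leading and trailing spaces.
static std::string TrimSpaces(const std::string& s) {
  const std::size_t first = s.find_first_not_of(' ');
  if (first == std::string::npos) return "";
  const std::size_t last = s.find_last_not_of(' ');
  return s.substr(first, last - first + 1);
}

// Splits "<name>:<type>" or "<name>:<type>=<default>" into *out.
// Returns false if the colon is missing or the name is malformed.
// Simplification: the first '=' that is not part of ">=" starts the default.
bool ParseAttrSpecSketch(const std::string& spec, AttrSpecSketch* out) {
  const std::size_t colon = spec.find(':');
  if (colon == std::string::npos) return false;

  AttrSpecSketch parsed;
  parsed.name = TrimSpaces(spec.substr(0, colon));
  static const std::regex kNamePattern("[a-zA-Z][a-zA-Z0-9_]*");
  if (!std::regex_match(parsed.name, kNamePattern)) return false;

  const std::string rest = spec.substr(colon + 1);
  std::size_t eq = std::string::npos;
  for (std::size_t i = 0; i < rest.size(); ++i) {
    if (rest[i] == '=' && (i == 0 || rest[i - 1] != '>')) {
      eq = i;
      break;
    }
  }
  if (eq == std::string::npos) {
    parsed.type = TrimSpaces(rest);
  } else {
    parsed.type = TrimSpaces(rest.substr(0, eq));
    parsed.has_default = true;
    parsed.default_value = TrimSpaces(rest.substr(eq + 1));
  }
  *out = parsed;
  return true;
}
```

For example, "N: int >= 2" yields the type "int >= 2" with no default, because the '=' belongs to the ">=" restriction rather than to a default value.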
Implement the kernel for the op
#include "tensorflow/core/framework/op_kernel.h"
using namespace tensorflow;
class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat<int32>();

    // Create an output tensor
    Tensor* output_tensor = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                     &output_tensor));
    auto output_flat = output_tensor->flat<int32>();

    // Set all but the first element of the output tensor to 0.
    const int N = input.size();
    for (int i = 1; i < N; i++) {
      output_flat(i) = 0;
    }

    // Preserve the first input value if possible.
    if (N > 0) output_flat(0) = input(0);
  }
};

// Register the kernel for the CPU device so the runtime can find it.
REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);
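To check one's understanding of what the kernel computes, its logic can be reproduced on a plain buffer without any TensorFlow types. The following is a minimal sketch. `ZeroOutSketch` is a hypothetical name, and a `std::vector<int32_t>` stands in for the flattened tensor.

```cpp
#include <cstdint>
#include <vector>

// Standalone sketch of the ZeroOut semantics (no TensorFlow types involved):
// keep element 0, set every other element to 0.
std::vector<int32_t> ZeroOutSketch(const std::vector<int32_t>& input) {
  // The output has the same size as the input and starts out all zeros,
  // which covers the loop over elements 1..N-1 in the kernel.
  std::vector<int32_t> output(input.size(), 0);

  // Preserve the first input value if there is one.
  if (!input.empty()) output[0] = input[0];
  return output;
}
```

An empty input gives an empty output, which matches the `N > 0` guard in the kernel.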
The Compute method must be thread-safe.
To keep track of op state, a ResourceMgr can be used.
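The thread-safety requirement matters because the same kernel instance may be invoked from several threads at once. Below is a minimal standalone sketch of guarding member state with a mutex. `CountingKernelSketch` and `CallConcurrently` are hypothetical names made up for this note. The class only imitates the shape of an OpKernel and uses the standard library's `std::mutex` rather than TensorFlow's own mutex type.

```cpp
#include <cstdint>
#include <mutex>
#include <thread>
#include <vector>

// Hypothetical stand-in for an OpKernel that keeps mutable member state.
class CountingKernelSketch {
 public:
  // Plays the role of Compute(): may be called from many threads at once.
  void Compute() {
    std::lock_guard<std::mutex> lock(mu_);  // guard every access to calls_
    ++calls_;
  }

  // Returns how many times Compute() has run so far.
  int64_t calls() const {
    std::lock_guard<std::mutex> lock(mu_);
    return calls_;
  }

 private:
  mutable std::mutex mu_;
  int64_t calls_ = 0;  // shared state, protected by mu_
};

// Calls Compute() `calls_per_thread` times from each of `num_threads` threads
// and waits for all of them to finish.
void CallConcurrently(CountingKernelSketch* kernel, int num_threads,
                      int calls_per_thread) {
  std::vector<std::thread> threads;
  for (int t = 0; t < num_threads; ++t) {
    threads.emplace_back([kernel, calls_per_thread] {
      for (int i = 0; i < calls_per_thread; ++i) kernel->Compute();
    });
  }
  for (auto& th : threads) th.join();
}
```

Without the lock, the increments from different threads could race and some would be lost. Avoiding member state altogether sidesteps the problem.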
Multi-threaded CPU kernels
To implement a multi-threaded CPU version of a kernel, refer to work_sharder.h.
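The idea behind sharding can be sketched without TensorFlow: split the index range [0, total) into contiguous blocks and hand each block to a worker. In the sketch below, `ShardSketch` and `ShardedZeroOutSketch` are hypothetical helpers written for this note. This is not TensorFlow's `Shard` function, which additionally works with a thread pool and a per-unit cost estimate, so check work_sharder.h for the actual signature.

```cpp
#include <algorithm>
#include <cstdint>
#include <functional>
#include <thread>
#include <vector>

// Hypothetical helper, NOT TensorFlow's Shard(): splits [0, total) into
// contiguous blocks and runs work(start, limit) for each block on its own thread.
void ShardSketch(int num_threads, int64_t total,
                 const std::function<void(int64_t, int64_t)>& work) {
  if (total <= 0) return;
  num_threads = std::max(1, num_threads);
  const int64_t block = (total + num_threads - 1) / num_threads;  // ceiling division
  std::vector<std::thread> threads;
  for (int64_t start = 0; start < total; start += block) {
    const int64_t limit = std::min(start + block, total);
    threads.emplace_back([&work, start, limit] { work(start, limit); });
  }
  for (auto& th : threads) th.join();
}

// ZeroOut over a flat buffer, with the zeroing loop split across threads.
std::vector<int32_t> ShardedZeroOutSketch(const std::vector<int32_t>& input,
                                          int num_threads) {
  std::vector<int32_t> output(input);  // start from a copy of the input
  ShardSketch(num_threads, static_cast<int64_t>(output.size()),
              [&output](int64_t start, int64_t limit) {
                // Each shard writes a disjoint index range, so no lock is
                // needed. Index 0 is skipped to preserve the first value.
                for (int64_t i = std::max<int64_t>(start, 1); i < limit; ++i) {
                  output[i] = 0;
                }
              });
  return output;
}
```

Because the shards cover disjoint ranges, the workers never touch the same element, which is what makes this kind of kernel easy to parallelize.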
GPU kernels