Vernlium/vernlium.github.io

Tensorflow源码解析——算子注册

Opened this issue · 0 comments

Tensorflow源码解析——算子注册

什么是op

op和kernel是TF框架中最重要的两个概念,如果一定要做一个类比的话,可以认为op相当于函数声明,kernel相当于函数实现。举个例子,对于矩阵相乘,可以声明一个op叫做MatMul,指明它的名称,输入,输出,参数,以及对参数的限制等。op只是告诉我们,这个操作的目的是什么,op内部有哪些可定制的东西,但不会提供具体实现。op在某种设备上的具体实现方法,是由kernel决定的。TF的计算图由节点构成,而每个节点对应了一个op,在构建计算图时,我们只知道不同节点对应的操作是什么,而不知道运行时这个操作是怎样实现的。也就是说,op是编译期概念,而kernel是运行期概念。

那为什么要把操作和它的实现分离呢?是为了实现TF代码的可移植性。我们可以把TF构建的计算图想象为Java的字节码,而计算图在执行的时候,需要考虑可用的设备资源,相当于我们在运行Java字节码的时候,需要考虑当前所在的操作系统,选择合适的字节码实现。因为TF的目标是在多设备上运行,但我们在编码的时候,是无法预先知道某一个操作具体是在哪种设备上运行的,因此,将op和它的实现分离,可以让我们在设计计算图的时候,更专注于它的结构,而不是具体实现。当我们构建完成一个计算图之后,在一个包含GPU的设备上,它可以利用对应操作在GPU上的kernel,充分利用GPU的高计算性能,在一个仅包含CPU的设备上,它也可以利用对应操作在CPU上的kenrel,完成计算功能。这就提高了TF代码在不同设备之间的可移植性。

注册方式

下面是tensorflow代码中注册Argmax算子的代码:

REGISTER_OP("ArgMax")
    .Input("input: T")
    .Input("dimension: Tidx")
    .Output("output: output_type")
    .Attr("T: numbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .Attr("output_type: {int32, int64} = DT_INT64")
    .SetShapeFn(ArgOpShape);

通过REGISTER_OP宏进行算子注册,注册的内容有:

  • Input:算子的输入
  • Output:算子的输出
  • Attr:算子的属性,比如Argmax算子,有个属性是axis,在哪根轴上求最大值的下标
  • ShapeFn:用于shape推断

下面分析这个算子是如何被注册进去的。

OpDef

OpDef的定义在tensorflow\core\framework\op_def.proto

message OpDef {
  // Op names starting with an underscore are reserved for internal use.
  // Names should be CamelCase and match the regexp "[A-Z][a-zA-Z0-9_]*".
  string name = 1;

  // For describing inputs and outputs.
  message ArgDef {
    // Name for the input/output.  Should match the regexp "[a-z][a-z0-9_]*".
    string name = 1;

    // Human readable description.
    string description = 2;

    DataType type = 3;
    string type_attr = 4;    // if specified, attr must have type "type"
    string number_attr = 5;  // if specified, attr must have type "int"
    // If specified, attr must have type "list(type)", and none of
    // type, type_attr, and number_attr may be specified.
    string type_list_attr = 6;

    bool is_ref = 16;
  };

OpDef中最核心的数据成员是操作名称、输入、输出、参数。

对于其中的几个难理解的点,作出说明:

  • ArgDef中的3-6四个字段,是为了描述·输入或输出的类型。当输入或输出是一个张量时,type或type_attr被设置为这个张量的数据类型,当输入或输出是一个由相同数据类型的张量构成的序列时,number_attr被设置为int对应的标识,当输入或输出是一个由张量构成的列表时,type_list_attr被设置为list(type)对应的标识;
  • AttrDef中的has_minimum字段,表明这个属性是否有最小值,如果数据类型是int,那么minimum就是允许的最小值,如果数据类型是列表,那么minimum就是列表的最短长度,is_aggregate这个字段,表明当前的操作是否是可聚集的。一个可聚集的操作是,能接受任意数量相同类型和形状的输入,并且保持输出与每个输入的类型和形状相同,这个字段对于操作的优化非常重要,如果一个操作是可聚集的,并且其输入来自多个不同的设备,那么我们就可以把聚集优化成一个树形的操作,先在设备内部对输入做聚集,最后在操作所在的设备集中,这样可以提高效率。这种优化对于分布式的机器学习模型训练非常有帮助,Spark ML中的TreeAggregate就实现了这样的优化。
  • is_stateful这个字段,表明当前的op是否带有状态的,什么操op会带有状态呢?比如Variable;

通过protoc工具用proto文件生成.h文件。命令为:

./protoc \ 
-I=/home/anan/tensorflow1.12/tensorflow-1.12.0/ \
--cpp_out=/home/anan/tensorflow1.12/tensorflow-1.12.0/tensorflow/core/framework/
/home/z00354782/tensorflow_1.12/tensorflow-
1.12.0/tensorflow/core/framework/op_def.proto

从中找到OpDef的定义:

class OpDef : public::google::protobuf::Message {
private:
    ::google::protobuf::RepeatedPtrField<::tensorflow::OpDef_ArgDef> input_arg_;
    ::google::protobuf::RepeatedPtrField<::tensorflow::OpDef_ArgDef> output_arg_;
    ::google::protobuf::RepeatedPtrField<::tensorflow::OpDef_ArgDef> attr_;
    ::google::protobuf::internal::ArenaStringPtr name_;
    ::google::protobuf::internal::ArenaStringPtr summary_;
    ::google::protobuf::internal::ArenaStringPtr description_;
    bool is_commutative_;
    bool is_aggregate_;
    bool is_stateful_;
    bool allows_uninitialized_input_;
}

为了方便进行OpDef的构建,TF还设计了OpDefBuilder类,它的私有数据成员如下:

// Builder class passed to the REGISTER_OP() macro.
class OpDefBuilder {
 public:
  // ...

 private:
  OpRegistrationData op_reg_data_;
  std::vector<string> attrs_;
  std::vector<string> inputs_;
  std::vector<string> outputs_;
  std::vector<string> control_outputs_;
  string doc_;
  std::vector<string> errors_;
};

可以看到,除了errors_字段外,其他内容几乎就是把OpDef的结构原封不动的搬了过来。

op_def_builder.h中还有一个新的结构,OpRegistrationData,他的结构如下:

struct OpRegistrationData {
 public:
  OpRegistrationData() {}
  OpRegistrationData(const OpDef& def) : op_def(def) {}
  OpRegistrationData(const OpDef& def, const OpShapeInferenceFn& fn,
                     bool is_function = false)
      : op_def(def), shape_inference_fn(fn), is_function_op(is_function) {}

  OpDef op_def;
  OpShapeInferenceFn shape_inference_fn;
  bool is_function_op = false;
};

在这个结构中,除了屋面熟知的OpDef之外,还包含一个OpShapeInferenceFn结构,他的定义如下:

typedef std::function<Status(shape_inference::InferenceContext* c)>
    OpShapeInferenceFn;

这个结构的定义中,涉及到了我们后面要讲到的形状推断的内容,这里我们只需要知道,OpShapeInferenceFn是一个帮助操作根据输入形状对输出形状进行推断的函数即可。

Op注册

上面的例子中使用REGISTER_OP宏进行Op注册,看一下这个宏的定义:

#define REGISTER_OP(name) REGISTER_OP_UNIQ_HELPER(__COUNTER__, name)
#define REGISTER_OP_UNIQ_HELPER(ctr, name) REGISTER_OP_UNIQ(ctr, name)
#define REGISTER_OP_UNIQ(ctr, name)                                          \
  static ::tensorflow::register_op::OpDefBuilderReceiver register_op##ctr    \
      TF_ATTRIBUTE_UNUSED =                                                  \
          ::tensorflow::register_op::OpDefBuilderWrapper<SHOULD_REGISTER_OP( \
              name)>(name)

注:__COUNTER__宏表示自动计数,最终的定义是register_op0register_op1register_op2依次往后排。

  static ::tensorflow::register_op::OpDefBuilderReceiver register_op0   = \
    ::tensorflow::register_op::OpDefBuilderWrapper<true>("Argmax") \
                            .Input("input: T")
                            .Input("dimension: Tidx")
                            .Output("output: output_type")
                            .Attr("T: numbertype")
                            .Attr("Tidx: {int32, int64} = DT_INT32")
                            .Attr("output_type: {int32, int64} = DT_INT64")
                            .SetShapeFn(ArgOpShape);

也就是说,生成一个OpDefBuilderWrapper对象,并链式调用它的InputOutputAttr等方法。

OpDefBuilderWrapper的定义为:

// Template specialization that forwards all calls to the contained builder.
template <>
class OpDefBuilderWrapper<true> {
 public:
  explicit OpDefBuilderWrapper(const char name[]) : builder_(name) {}
  OpDefBuilderWrapper<true>& Attr(string spec) {
    builder_.Attr(std::move(spec));
    return *this;
  }
  OpDefBuilderWrapper<true>& Input(string spec) {
    builder_.Input(std::move(spec));
    return *this;
  }
  OpDefBuilderWrapper<true>& Output(string spec) {
    builder_.Output(std::move(spec));
    return *this;
  }
  OpDefBuilderWrapper<true>& SetIsCommutative() {
    builder_.SetIsCommutative();
    return *this;
  }
  OpDefBuilderWrapper<true>& SetIsAggregate() {
    builder_.SetIsAggregate();
    return *this;
  }
  OpDefBuilderWrapper<true>& SetIsStateful() {
    builder_.SetIsStateful();
    return *this;
  }
  OpDefBuilderWrapper<true>& SetAllowsUninitializedInput() {
    builder_.SetAllowsUninitializedInput();
    return *this;
  }
  OpDefBuilderWrapper<true>& Deprecated(int version, string explanation) {
    builder_.Deprecated(version, std::move(explanation));
    return *this;
  }
  OpDefBuilderWrapper<true>& Doc(string text) {
    builder_.Doc(std::move(text));
    return *this;
  }
  OpDefBuilderWrapper<true>& SetShapeFn(
      Status (*fn)(shape_inference::InferenceContext*)) {
    builder_.SetShapeFn(fn);
    return *this;
  }
  const ::tensorflow::OpDefBuilder& builder() const { return builder_; }

 private:
  mutable ::tensorflow::OpDefBuilder builder_;
};

通过链式调用,把Input、Output、Attr等描述保存到OpDefBuiIder的attrs_、inputs_、outputs_属性中。例如,Input的处理为:

OpDefBuilder& OpDefBuilder::Input(string spec) {
  inputs_.push_back(std::move(spec));
  return *this;
}

OpDefBuilderWrapperOpDefBuilder的包装器,其成员包含一个OpDefBuilder的对象,它的API都是设置型的,且都返回对象本身,提供 链式的方式进行属性设置。值得注意的是,这个类名后面跟着一个true,它的含义等会再看。

最终把OpDefBuilderWrapper类型的对象用于构造OpDefBuilderReceiver

OpDefBuilderReceiver定义为:

struct OpDefBuilderReceiver {
  // To call OpRegistry::Global()->Register(...), used by the
  // REGISTER_OP macro below.
  // Note: These are implicitly converting constructors.
  OpDefBuilderReceiver(
      const OpDefBuilderWrapper<true>& wrapper);  // NOLINT(runtime/explicit)
  constexpr OpDefBuilderReceiver(const OpDefBuilderWrapper<false>&) {
  }  // NOLINT(runtime/explicit)
};
}  // namespace register_op

OpDefBuilderReceiver的构造函数的实现为:

OpDefBuilderReceiver::OpDefBuilderReceiver(
    const OpDefBuilderWrapper<true>& wrapper) {
  OpRegistry::Global()->Register(
      [wrapper](OpRegistrationData* op_reg_data) -> Status {
        return wrapper.builder().Finalize(op_reg_data);
      });
}

相当于是OpDefBuilderWrapper构造时,以OpDefBuilderWrapper为参数,在构造函数中调用OpRegistry::Global()->Register(...)

也就是说,REGISTER_OP绕了一圈,先用OpDefBuilderWrapper对操作进行封装,然后把它作为参数传递给OpDefBuilderReceiver的构造函数,而在这个构造函数中,完成了对算子的注册。

真正的注册过程就是OpRegistryRegister方法中完成的,下面具体看一下注册类的实现。

注册类

为了方便对操作进行统一管理,TF提出了OP注册器的概念。这个OP注册器的作用,是为各种OP提供一个统一的管理接囗。

操作注册类的继承结构如下:

image

其中,OpRegistryInterface是一个接口类,它提供了注册类最基础的查找功能:

// Users that want to look up an OpDef by type name should take an
// OpRegistryInterface.  Functions accepting a
// (const) OpRegistryInterface* may call LookUp() from multiple threads.
class OpRegistryInterface {
 public:
  virtual ~OpRegistryInterface();

  // Returns an error status and sets *op_reg_data to nullptr if no OpDef is
  // registered under that name, otherwise returns the registered OpDef.
  // Caller must not delete the returned pointer.
  virtual Status LookUp(const string& op_type_name,
                        const OpRegistrationData** op_reg_data) const = 0;

  // Shorthand for calling LookUp to get the OpDef.
  Status LookUpOpDef(const string& op_type_name, const OpDef** op_def) const;
};

OpRegistry类继承了OpRegistryInterface类。

// The standard implementation of OpRegistryInterface, along with a
// global singleton used for registering ops via the REGISTER
// macros below.  Thread-safe.
//
// Example registration:
//   OpRegistry::Global()->Register(
//     [](OpRegistrationData* op_reg_data)->Status {
//       // Populate *op_reg_data here.
//       return Status::OK();
//   });
class OpRegistry : public OpRegistryInterface {
 public:
  typedef std::function<Status(OpRegistrationData*)> OpRegistrationDataFactory;

  OpRegistry();
  ~OpRegistry() override;

  void Register(const OpRegistrationDataFactory& op_data_factory);

  Status LookUp(const string& op_type_name,
                const OpRegistrationData** op_reg_data) const override;

  // Fills *ops with all registered OpDefs (except those with names
  // starting with '_' if include_internal == false) sorted in
  // ascending alphabetical order.
  void Export(bool include_internal, OpList* ops) const;

  // Returns ASCII-format OpList for all registered OpDefs (except
  // those with names starting with '_' if include_internal == false).
  string DebugString(bool include_internal) const;

  // A singleton available at startup.
  static OpRegistry* Global();

  // Get all registered ops.
  void GetRegisteredOps(std::vector<OpDef>* op_defs);

  // Get all `OpRegistrationData`s.
  void GetOpRegistrationData(std::vector<OpRegistrationData>* op_data);

  // Watcher, a function object.
  // The watcher, if set by SetWatcher(), is called every time an op is
  // registered via the Register function. The watcher is passed the Status
  // obtained from building and adding the OpDef to the registry, and the OpDef
  // itself if it was successfully built. A watcher returns a Status which is in
  // turn returned as the final registration status.
  typedef std::function<Status(const Status&, const OpDef&)> Watcher;

  // An OpRegistry object has only one watcher. This interface is not thread
  // safe, as different clients are free to set the watcher any time.
  // Clients are expected to atomically perform the following sequence of
  // operations :
  // SetWatcher(a_watcher);
  // Register some ops;
  // op_registry->ProcessRegistrations();
  // SetWatcher(nullptr);
  // Returns a non-OK status if a non-null watcher is over-written by another
  // non-null watcher.
  Status SetWatcher(const Watcher& watcher);

  // Process the current list of deferred registrations. Note that calls to
  // Export, LookUp and DebugString would also implicitly process the deferred
  // registrations. Returns the status of the first failed op registration or
  // Status::OK() otherwise.
  Status ProcessRegistrations() const;

  // Defer the registrations until a later call to a function that processes
  // deferred registrations are made. Normally, registrations that happen after
  // calls to Export, LookUp, ProcessRegistrations and DebugString are processed
  // immediately. Call this to defer future registrations.
  void DeferRegistrations();

  // Clear the registrations that have been deferred.
  void ClearDeferredRegistrations();

 private:
  // Ensures that all the functions in deferred_ get called, their OpDef's
  // registered, and returns with deferred_ empty.  Returns true the first
  // time it is called. Prints a fatal log if any op registration fails.
  bool MustCallDeferred() const EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Calls the functions in deferred_ and registers their OpDef's
  // It returns the Status of the first failed op registration or Status::OK()
  // otherwise.
  Status CallDeferred() const EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Add 'def' to the registry with additional data 'data'. On failure, or if
  // there is already an OpDef with that name registered, returns a non-okay
  // status.
  Status RegisterAlreadyLocked(const OpRegistrationDataFactory& op_data_factory)
      const EXCLUSIVE_LOCKS_REQUIRED(mu_);

  Status LookUpSlow(const string& op_type_name,
                    const OpRegistrationData** op_reg_data) const;

  mutable mutex mu_;
  // Functions in deferred_ may only be called with mu_ held.
  mutable std::vector<OpRegistrationDataFactory> deferred_ GUARDED_BY(mu_);
  // Values are owned.
  mutable std::unordered_map<string, const OpRegistrationData*> registry_
      GUARDED_BY(mu_);
  mutable bool initialized_ GUARDED_BY(mu_);

  // Registry watcher.
  mutable Watcher watcher_ GUARDED_BY(mu_);
};

OpRegistry类是单例模式,通过Global获取单例对象,并且是线程安全的。

注册函数Register的定义为:

void OpRegistry::Register(const OpRegistrationDataFactory& op_data_factory) {
  mutex_lock lock(mu_);
  if (initialized_) {
    TF_QCHECK_OK(RegisterAlreadyLocked(op_data_factory));
  } else {
    deferred_.push_back(op_data_factory);
  }
}

其中,OpRegistrationDataFactory是一个function类型:

typedef std::function<Status(OpRegistrationData*)> OpRegistrationDataFactory;

也就是说,Register注册时传入的是一个函数,最终在Register中完成对函数的调用。

从代码看,只有RegisterAlreadyLocked(op_data_factory)中可能产生对op_data_factory的调用,所以可以从这儿入手看注册过程。姑且不论initialized_字段的值。

// Add 'def' to the registry with additional data 'data'. On failure, or if
// there is already an OpDef with that name registered, returns a non-okay
// status.
Status OpRegistry::RegisterAlreadyLocked(
    const OpRegistrationDataFactory& op_data_factory) const {
  std::unique_ptr<OpRegistrationData> op_reg_data(new OpRegistrationData);
  Status s = op_data_factory(op_reg_data.get());
  if (s.ok()) {
    s = ValidateOpDef(op_reg_data->op_def);
    if (s.ok() &&
        !gtl::InsertIfNotPresent(&registry_, op_reg_data->op_def.name(),
                                 op_reg_data.get())) {
      s = errors::AlreadyExists("Op with name ", op_reg_data->op_def.name());
    }
  }
  Status watcher_status = s;
  if (watcher_) {
    watcher_status = watcher_(s, op_reg_data->op_def);
  }
  if (s.ok()) {
    op_reg_data.release();
  } else {
    op_reg_data.reset();
  }
  return watcher_status;
}

函数的注释写的很清楚了,新增一个def到register中。失败或者算子name已经被注册,返回非okey结果。

这个函数中构造了一个OpRegistrationData对象,并最终对op_data_factory进行了调用。

OpRegistrationData的定义如下,其中包含了一个OpDef的变量。

struct OpRegistrationData {
 public:
  OpRegistrationData() {}
  OpRegistrationData(const OpDef& def) : op_def(def) {}
  OpRegistrationData(const OpDef& def, const OpShapeInferenceFn& fn,
                     bool is_function = false)
      : op_def(def), shape_inference_fn(fn), is_function_op(is_function) {}

  OpDef op_def;
  OpShapeInferenceFn shape_inference_fn;
  bool is_function_op = false;
};

op_data_factory的调用构造了一个OpRegistrationData空对象,最终进入wrapper.builder().Finalize(op_reg_data)中进行处理。

wrapper.builder()返回的是OpDefBuilder对象。函数Finalize的实现为:

Status OpDefBuilder::Finalize(OpRegistrationData* op_reg_data) const {
  std::vector<string> errors = errors_;
  *op_reg_data = op_reg_data_;

  OpDef* op_def = &op_reg_data->op_def;
  for (StringPiece attr : attrs_) {
    FinalizeAttr(attr, op_def, &errors);
  }
  for (StringPiece input : inputs_) {
    FinalizeInputOrOutput(input, false, op_def, &errors);
  }
  for (StringPiece output : outputs_) {
    FinalizeInputOrOutput(output, true, op_def, &errors);
  }
  for (StringPiece control_output : control_outputs_) {
    FinalizeControlOutput(control_output, op_def, &errors);
  }
  FinalizeDoc(doc_, op_def, &errors);

  if (errors.empty()) return Status::OK();
  return errors::InvalidArgument(str_util::Join(errors, "\n"));
}

这里把最开始wrapper中保存的inputs_outputs_attrs_等信息依次取出,用于构建OpDef对象。

得到的OpDef对象首先经过ValidateOpDef(op_reg_data->op_def);进行校验,然后插入到Registerregistry_中。

gtl::InsertIfNotPresent(&registry_, op_reg_data->op_def.name(),
                                 op_reg_data.get()))

到这里就完成了一个算子的注册过程。

下面这个代码值得注意:

  if (initialized_) {
    TF_QCHECK_OK(RegisterAlreadyLocked(op_data_factory));
  } else {
    deferred_.push_back(op_data_factory);
  }

只有在initialized_是true时,才进行注册,否则把op_data_factory放到deferred_这个vector中。

注意到Register类有如下两个方法:

// Ensures that all the functions in deferred_ get called, their OpDef's
// registered, and returns with deferred_ empty.  Returns true the first
// time it is called. Prints a fatal log if any op registration fails.
bool OpRegistry::MustCallDeferred() const {
  if (initialized_) return false;
  initialized_ = true;
  for (size_t i = 0; i < deferred_.size(); ++i) {
    TF_QCHECK_OK(RegisterAlreadyLocked(deferred_[i]));
  }
  deferred_.clear();
  return true;
}

// Calls the functions in deferred_ and registers their OpDef's
// It returns the Status of the first failed op registration or Status::OK()
// otherwise.
Status OpRegistry::CallDeferred() const {
  if (initialized_) return Status::OK();
  initialized_ = true;
  for (size_t i = 0; i < deferred_.size(); ++i) {
    Status s = RegisterAlreadyLocked(deferred_[i]);
    if (!s.ok()) {
      return s;
    }
  }
  deferred_.clear();
  return Status::OK();
}

可以看出,在特定的调用中,把deferred_中保存的算子注册函数全部取出,执行RegisterAlreadyLocked真正的执行算子注册过程。

这里有几点值得关注:

  • 注册函数Register的输入是一个函数引用,这个函数接收一个OpRegistrationData指针作为输入;
  • Watcher是一个监视器,当每次注册一个算子的时候,在注册步骤的最后都要调用一下这个监视器,它可方便对注册的操作进行监控,所有的算子注册动作都逃不过它的眼,可以根据需求定制特殊Watcher;
  • registry_`是已注册的算子真正存放的位置,它的结构很简单,是一个key为算子名、value为算子数据的map;
  • initialized_deferred_是与注册模式相关的两个数据成员,注册器在注册操作时,分为两种模式:
    • 即时注册模式懒惰注册模式
    • 注册模式通过initialized_字段区分,true为即时注册模式,false为懒惰注册模式;
    • 在懒惰注册模式中,被注册的算子先 被保存在deferred_向量中,在特定的函数调用时再将deferred_中的算子注册到registryy_,而即时注册模式下,待注册的算子不用经过deferred_,直接注册到registry_
      -懒惰注册模式的使用场景是,部分算子组合的注册是原子的,即要么全部注册,要么全部不注册,因为这些算子之间可能会有相互依赖关系。
    • 构造函数将initialized_设置为false,进入懒惰注册模式,随后一旦调用了MustCallDeferred或者CallDeferred中的任意一个,都会将initialized_设置为true,进入即时注册模式。想要重新返回懒惰注册模式也很简单,只需要调用DeferRegistrations即可。

参考

https://www.cnblogs.com/jicanghai/p/9539513.html

注:文中代码基于tensorflow1.12.0版本。