ONNX Runtime 源码阅读:Graph::SetGraphInputsOutputs() 函数

科技资讯 投稿 36200 0 评论

ONNX Runtime 源码阅读:Graph::SetGraphInputsOutputs() 函数

前言

为了深入理解ONNX Runtime的底层机制,本文将对 Graph::SetGraphInputsOutputs( 的代码逐行分析。

正文

首先判断是否从ONNX文件中加载所得:

if (is_loaded_from_model_file_ return Status::OK(;

如果是,可直接返回;如果不是,则需要解析中的节点,从而设置模型的输入和输出。

将中的成员变量 、、 以及  全部清空:

value_info_.clear(;

graph_inputs_excluding_initializers_.clear(;

if (!graph_inputs_manually_set_ {
  graph_inputs_including_initializers_.clear(;
} else {
  std::unordered_set<std::string> existing_names;
  for (auto arg : graph_inputs_including_initializers_ {
    const std::string& name = arg->Name(;
    if (existing_names.count(name == 0 {
      graph_inputs_excluding_initializers_.push_back(arg;
      existing_names.insert(name;
    }
  }
}

if (!graph_outputs_manually_set_ {
  graph_outputs_.clear(;
}

设置一些局部变量,方便下面的使用分析:

std::unordered_map<std::string, size_t> output_name_to_node_arg_index;
std::vector<const NodeArg*> output_node_args_in_order;
std::unordered_set<std::string> added_input_names{outer_scope_node_arg_names_};

统计所有节点的输出,添加到以上局部变量(output_name_to_node_arg_index 和 output_node_args_in_order)中:

for (const auto& node : Nodes( {
  for (const auto* output_def : node.OutputDefs( {
    if (output_def->Exists( {
      output_node_args_in_order.push_back(output_def;
      output_name_to_node_arg_index.insert({output_def->Name(, output_node_args_in_order.size( - 1};
    }
  }
}
auto graph_output_args = output_name_to_node_arg_index;  // 拷贝一份输出节点map

然后遍历图中每个节点以及每个节点的输入:

for (const auto& node : Nodes( {
  // Go thru all node's inputs.
  for (const auto* input_arg : node.InputDefs( {
    ...
  }
}

在输出节点列表中查找当前输入:

auto output_arg_iter = output_name_to_node_arg_index.find(input_arg->Name(;

如果没有找到,说明这个节点的输入就是图的输入,接下来还需要判断这个输入是否已经放在局部变量中:

if (output_name_to_node_arg_index.end( == output_arg_iter {
  // This input arg is not the output of another node so must come from either a graph input or an initializer.
  const std::string& name = input_arg->Name(;
  if (added_input_names.end( == added_input_names.find(name {
    ...
  }
}

如果已经放到局部变量中,就可以判断节点的下一个输入或者下一个节点的输入。如果没有放到局部变量中:

bool is_initializer = name_to_initial_tensor_.find(name != name_to_initial_tensor_.end(;  // 判断当前input_arg是否已初始化过的tensor,如果是就不可以再放置到 graph_inputs_excluding_initializers_ 中
if (!graph_inputs_manually_set_ {   // 如果未主动调用 SetInputs( 方法
  // if IR version < 4 all initializers must have a matching graph input
  // (even though the graph input is not allowed to override the initializer.
  // if IR version >= 4 initializers are not required to have a matching graph input.
  // any graph inputs that are to override initializers must be specified by calling SetInputs.
  if (!is_initializer || ir_version_ < 4 {
    graph_inputs_including_initializers_.push_back(input_arg;
  }
  if (!is_initializer {
    // If input_arg is not of an initializer, we add it into graph_inputs_excluding_initializers_.
    graph_inputs_excluding_initializers_.push_back(input_arg;
  }
} else {  // 如果主动调用了 SetInputs( 方法
  // graph_inputs_including_initializers_ has been manually populated by SetInputs.
  // Validation: the <input_arg> must be in graph inputs or initializers when it's manually set.
  if (!is_initializer {
    const auto& inputs = graph_inputs_including_initializers_;
    bool in_inputs = std::find(inputs.begin(, inputs.end(, input_arg != inputs.end(;
    if (!in_inputs {
      return Status(ONNXRUNTIME, FAIL,
                    name + " must be either specified in graph inputs or graph initializers.";
    }
  } else {
    // If arg_input is of an initializer, we remove it from graph_inputs_excluding_initializers_
    // whose initial content has both initializers and non-initializers.
    auto input_pos = std::find(graph_inputs_excluding_initializers_.begin(,
                                graph_inputs_excluding_initializers_.end(,
                                input_arg;
    if (input_pos != graph_inputs_excluding_initializers_.end( {
      graph_inputs_excluding_initializers_.erase(input_pos;
    }
  }
}
added_input_names.insert(name;

可以看到,这里会把当前的  分别放到  和  中,并将放在中。

如果该输入的已经在输出节点列表中,说明这个节点是中间输出结果,而非整个图的输出,因此应该将其从图的输出()中删除,并放在  中:

if (output_name_to_node_arg_index.end( == output_arg_iter {
  ...
}else if(graph_output_args.erase(output_arg_iter->first >= 1{
  value_info_.insert(input_arg;
}

以上我们对Graph的三个成员变量:、和分别进行了赋值,其中前两者存储输入,后者存储中间结果。我们还需要处理图的输出结果:`graph_outputs_`:

if (!graph_outputs_manually_set_ {
  // Set graph outputs in order.
  std::vector<size_t> graph_output_args_index;
  graph_output_args_index.reserve(graph_output_args.size(;
  for (const auto& output_arg : graph_output_args {          // graph_output_args原本存储了所有节点的输出,但是前面的代码已经把中间节点的输出给移除了,因此剩下的就是整个Graph的输出
    graph_output_args_index.push_back(output_arg.second;
  }

  std::sort(graph_output_args_index.begin(, graph_output_args_index.end(;
  for (auto& output_arg_index : graph_output_args_index {
    graph_outputs_.push_back(output_node_args_in_order[output_arg_index];
  }
}

最后,还需要对  进行处理:

ComputeOverridableInitializers(;

进入这个函数内部:

void Graph::ComputeOverridableInitializers( {
  graph_overridable_initializers_.clear(;
  if (CanOverrideInitializer( {
    // graph_inputs_excluding_initializers_ and graph_inputs_including_initializers_
    // are inserted in the same order. So we walk and compute the difference.
    auto f_incl = graph_inputs_including_initializers_.cbegin(;
    const auto l_incl = graph_inputs_including_initializers_.cend(;
    auto f_excl = graph_inputs_excluding_initializers_.cbegin(;
    const auto l_excl = graph_inputs_excluding_initializers_.cend(;

    while (f_incl != l_incl {
      // Equal means not an initializer
      if (f_excl != l_excl && *f_incl == *f_excl {
        ++f_incl;
        ++f_excl;
        continue;
      }
      graph_overridable_initializers_.push_back(*f_incl;
      ++f_incl;
    }
  }
}

这是一个很简单的算法,通过比较  和 ,提取出  并放置到  中。

至此,我们完成了对  函数的解析。

总结

针对这个函数的解析不仅理解了如何从Graph的nodes中分析出graph的输入和输出,而且懂得了以及的作用。

编程笔记 » ONNX Runtime 源码阅读:Graph::SetGraphInputsOutputs() 函数

赞同 (35) or 分享 (0)
游客 发表我的评论   换个身份
取消评论

表情
(0)个小伙伴在吐槽