method

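See the example below. It scripts a module with a `forward(x, y)` method, looks the method up with `get_method`, checks its name and argument names, calls it, and retrieves its graph: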

./code/method/main.cc
#include <utility>

#include "torch/script.h"

 3static void TestHello() {
 4  torch::jit::Module m("m");
 5  m.define(R"(
 6    def forward(self, x: torch.Tensor, y: torch.Tensor):
 7      return x + y
 8  )");
 9
10  torch::jit::Method method = m.get_method("forward");
11  TORCH_CHECK(method.name() == "forward");
12
13  const std::vector<std::string> &names = method.getArgumentNames();
14  TORCH_CHECK(names.size() == 2);
15  TORCH_CHECK(names[0] == "x");
16  TORCH_CHECK(names[1] == "y");
17
18  std::vector<torch::IValue> args;
19  auto x = torch::tensor({1, 2});
20  auto y = torch::tensor({1, 2});
21  args.emplace_back(x);
22  args.emplace_back(y);
23  auto z = method(args).toTensor();
24
25  TORCH_CHECK(torch::equal(z, x + y));
26
27  std::shared_ptr<torch::jit::Graph> g = method.graph();
28  // see node/main.cc
29}
30
31int main() {
32  TestHello();
33  return 0;
34}