method
See:
./code/method/main.cc
1#include "torch/script.h"
2
3static void TestHello() {
4 torch::jit::Module m("m");
5 m.define(R"(
6 def forward(self, x: torch.Tensor, y: torch.Tensor):
7 return x + y
8 )");
9
10 torch::jit::Method method = m.get_method("forward");
11 TORCH_CHECK(method.name() == "forward");
12
13 const std::vector<std::string> &names = method.getArgumentNames();
14 TORCH_CHECK(names.size() == 2);
15 TORCH_CHECK(names[0] == "x");
16 TORCH_CHECK(names[1] == "y");
17
18 std::vector<torch::IValue> args;
19 auto x = torch::tensor({1, 2});
20 auto y = torch::tensor({1, 2});
21 args.emplace_back(x);
22 args.emplace_back(y);
23 auto z = method(args).toTensor();
24
25 TORCH_CHECK(torch::equal(z, x + y));
26
27 std::shared_ptr<torch::jit::Graph> g = method.graph();
28 // see node/main.cc
29}
30
31int main() {
32 TestHello();
33 return 0;
34}