153 lines
4.4 KiB
C++
153 lines
4.4 KiB
C++
#include <gtest/gtest.h>
|
|
#include <torch/csrc/jit/mobile/nnc/context.h>
|
|
#include <torch/csrc/jit/mobile/nnc/registry.h>
|
|
#include <ATen/Functions.h>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
namespace mobile {
|
|
namespace nnc {
|
|
|
|
extern "C" {
|
|
|
|
// out = a * n (doing calculation in the `tmp` buffer)
|
|
int slow_mul_kernel(void** args) {
|
|
const int size = 128;
|
|
at::Tensor a = at::from_blob(args[0], {size}, at::kFloat);
|
|
at::Tensor out = at::from_blob(args[1], {size}, at::kFloat);
|
|
at::Tensor n = at::from_blob(args[2], {1}, at::kInt);
|
|
at::Tensor tmp = at::from_blob(args[3], {size}, at::kFloat);
|
|
|
|
tmp.zero_();
|
|
for (int i = n.item().toInt(); i > 0; i--) {
|
|
tmp.add_(a);
|
|
}
|
|
out.copy_(tmp);
|
|
return 0;
|
|
}
|
|
|
|
// No-op kernel: ignores its argument slots and returns 0 (presumably the
// success code, matching what slow_mul_kernel returns on completion). The
// ValidInput/InvalidInput tests use it so only Function::run's handling of the
// inputs is exercised.
int dummy_kernel(void** args) {
  static_cast<void>(args); // argument slots are intentionally unused
  constexpr int kResult = 0;
  return kResult;
}
|
|
|
|
} // extern "C"
|
|
|
|
// Make slow_mul_kernel selectable under the id "slow_mul", which the
// ExecuteSlowMul test passes to Function::set_nnc_kernel_id().
// NOTE(review): no trailing ';' — presumably the macro expansion terminates
// itself; confirm in registry.h.
REGISTER_NNC_KERNEL("slow_mul", slow_mul_kernel)
|
|
// Make dummy_kernel selectable under the id "dummy", which the ValidInput and
// InvalidInput tests pass to Function::set_nnc_kernel_id().
REGISTER_NNC_KERNEL("dummy", dummy_kernel)
|
|
|
|
// Builds an InputSpec with the given shape and dtype at::kFloat, for use in the
// tests below.
InputSpec create_test_input_spec(const std::vector<int64_t>& sizes) {
  InputSpec spec;
  spec.dtype_ = at::kFloat;
  spec.sizes_ = sizes;
  return spec;
}
|
|
|
|
// Builds an OutputSpec with the given shape and dtype at::kFloat, for use in
// the tests below.
OutputSpec create_test_output_spec(const std::vector<int64_t>& sizes) {
  OutputSpec spec;
  spec.dtype_ = at::kFloat;
  spec.sizes_ = sizes;
  return spec;
}
|
|
|
|
// Builds a MemoryPlan whose buffer_sizes_ is a copy of `buffer_sizes`.
// The tests pass byte counts (e.g. sizeof(float) * N).
MemoryPlan create_test_memory_plan(const std::vector<int64_t>& buffer_sizes) {
  MemoryPlan plan;
  plan.buffer_sizes_ = buffer_sizes;
  return plan;
}
|
|
|
|
// Runs the "slow_mul" kernel end to end. The inputs are a float tensor filled
// with `a`, an int parameter `n`, and one scratch buffer from the memory plan.
// The single output must equal a * n everywhere.
TEST(Function, ExecuteSlowMul) {
  const int a = 999;
  const int n = 100;
  // Must match the element count hard-coded inside slow_mul_kernel.
  const int size = 128;
  Function f;

  f.set_nnc_kernel_id("slow_mul");
  f.set_input_specs({create_test_input_spec({size})});
  f.set_output_specs({create_test_output_spec({size})});
  f.set_parameters(c10::impl::toList(c10::List<at::Tensor>({
    at::ones({1}, at::kInt).mul(n)
  })));
  // One buffer of `size` floats: the kernel's scratch (`tmp`) slot.
  f.set_memory_plan(create_test_memory_plan({sizeof(float) * size}));

  c10::List<at::Tensor> input({
    at::ones({size}, at::kFloat).mul(a)
  });
  auto outputs = f.run(c10::impl::toList(input));
  // Fail cleanly instead of reading past the end if run() produced nothing.
  ASSERT_EQ(outputs.size(), 1u);
  // get() returns the IValue by value, so no C-style cast is needed to get
  // from the proxy returned by the non-const operator[] to a const IValue&.
  auto output = outputs.get(0).toTensor();
  auto expected_output = at::ones({size}, at::kFloat).mul(a * n);
  EXPECT_TRUE(output.equal(expected_output));
}
|
|
|
|
// Round-trips a fully populated Function through serialize() and the Function
// constructor that accepts its result. Checks that name, kernel id, specs,
// parameters and memory plan all come back unchanged.
TEST(Function, Serialization) {
  Function f;
  f.set_name("test_function");
  f.set_nnc_kernel_id("test_kernel");
  f.set_input_specs({create_test_input_spec({1, 3, 224, 224})});
  f.set_output_specs({create_test_output_spec({1000})});

  f.set_parameters(c10::impl::toList(c10::List<at::Tensor>({
    at::ones({1, 16, 3, 3}, at::kFloat),
    at::ones({16, 32, 1, 1}, at::kFloat),
    at::ones({32, 1, 3, 3}, at::kFloat)
  })));
  f.set_memory_plan(create_test_memory_plan({
    sizeof(float) * 1024,
    sizeof(float) * 2048,
  }));

  auto serialized = f.serialize();
  Function f2(serialized);
  EXPECT_EQ(f2.name(), "test_function");
  EXPECT_EQ(f2.nnc_kernel_id(), "test_kernel");

  // ASSERT (not EXPECT) on each container size: the indexed reads that follow
  // would run past the end if the round trip dropped elements. Unsigned
  // literals keep these comparisons free of -Wsign-compare.
  ASSERT_EQ(f2.input_specs().size(), 1u);
  EXPECT_EQ(f2.input_specs()[0].sizes_, std::vector<int64_t>({1, 3, 224, 224}));
  EXPECT_EQ(f2.input_specs()[0].dtype_, at::kFloat);

  ASSERT_EQ(f2.output_specs().size(), 1u);
  EXPECT_EQ(f2.output_specs()[0].sizes_, std::vector<int64_t>({1000}));
  EXPECT_EQ(f2.output_specs()[0].dtype_, at::kFloat);

  ASSERT_EQ(f2.parameters().size(), 3u);
  EXPECT_EQ(f2.parameters()[0].toTensor().sizes(), at::IntArrayRef({1, 16, 3, 3}));
  EXPECT_EQ(f2.parameters()[1].toTensor().sizes(), at::IntArrayRef({16, 32, 1, 1}));
  EXPECT_EQ(f2.parameters()[2].toTensor().sizes(), at::IntArrayRef({32, 1, 3, 3}));
  // Contents must survive too, not just shapes (every parameter is all ones).
  EXPECT_TRUE(f2.parameters()[0].toTensor().equal(at::ones({1, 16, 3, 3}, at::kFloat)));
  EXPECT_TRUE(f2.parameters()[1].toTensor().equal(at::ones({16, 32, 1, 1}, at::kFloat)));
  EXPECT_TRUE(f2.parameters()[2].toTensor().equal(at::ones({32, 1, 3, 3}, at::kFloat)));

  ASSERT_EQ(f2.memory_plan().buffer_sizes_.size(), 2u);
  // buffer_sizes_ is populated from a std::vector<int64_t>, so compare against
  // int64_t values rather than the unsigned size_t from sizeof.
  EXPECT_EQ(f2.memory_plan().buffer_sizes_[0], static_cast<int64_t>(sizeof(float) * 1024));
  EXPECT_EQ(f2.memory_plan().buffer_sizes_[1], static_cast<int64_t>(sizeof(float) * 2048));
}
|
|
|
|
// Function::run must accept an input whose shape matches the input spec.
// The "dummy" kernel does no work, so only input handling is exercised.
TEST(Function, ValidInput) {
  const int size = 128;
  Function f;
  f.set_nnc_kernel_id("dummy");
  f.set_input_specs({create_test_input_spec({size})});

  // Exactly the {size} shape declared by the spec.
  c10::List<at::Tensor> matching_input({at::ones({size}, at::kFloat)});
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  EXPECT_NO_THROW(f.run(c10::impl::toList(matching_input)));
}
|
|
|
|
// Function::run must reject, with a c10::Error, an input whose shape
// ({2 * size}) disagrees with the input spec ({size}).
TEST(Function, InvalidInput) {
  const int size = 128;
  Function f;
  f.set_nnc_kernel_id("dummy");
  f.set_input_specs({create_test_input_spec({size})});

  // Twice as many elements as the spec declares.
  c10::List<at::Tensor> mismatched_input({at::ones({size * 2}, at::kFloat)});
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  EXPECT_THROW(f.run(c10::impl::toList(mismatched_input)), c10::Error);
}
|
|
|
|
} // namespace nnc
|
|
} // namespace mobile
|
|
} // namespace jit
|
|
} // namespace torch
|