// NOTE(review): in this copy every `#include` directive is missing its target
// and template argument lists are absent throughout (e.g. `static_cast(...)`,
// `std::unordered_map inputs`, `template void ...`). This looks like damage to
// the text rather than intent -- restore from version control before building.
#include
#include
#include

#ifdef __SIGRID_USE_GPU__
#include
#include
#endif

namespace torch::nativert {

// Constructs the kernel for `node`.
//
// Forwards node/device/kind to the OpKernel base, resolves the operator for
// the node's target via getOperatorForTarget(), builds schema_ from the
// operator's schema plus the supplied aliasing spec, and prefills arguments_
// from the node's static arguments.
//
// NOTE(review): schema_ is given `kind_`, not the `kind` parameter; presumably
// `kind_` is the member set by the OpKernel constructor just above -- confirm.
C10Kernel::C10Kernel(
    const Node* node,
    c10::Device device,
    OpKernelKind kind,
    AliasingSpec&& aliasingSpec)
    : OpKernel(node, device, kind),
      op_(getOperatorForTarget(node->target(), node)),
      schema_(op_.schema(), std::move(aliasingSpec), kind_),
      arguments_(prefillStackWithStaticArgs(node, op_.schema())) {}

// Runs the node's operator through its boxed calling convention.
//
// Steps: copy the prefilled stack, fill in the dynamic inputs from
// `executionFrame`, call the operator, then store every entry left on the
// stack into the node output with the same index.
//
// Throws std::runtime_error when the operator throws a std::exception; the
// message carries the node, the call arguments, the original what() text and
// the node's "stack_trace" metadata (an empty string if that is absent).
void C10Kernel::computeInternal(ExecutionFrame& executionFrame) const {
  // Make a copy of the stack
  std::vector stack = arguments_.getStackWithStaticArgs();

  fillDynamicInputs(executionFrame, arguments_, stack);

  // Call the op with the prepared stack.
  try {
    op_.callBoxed(stack);
  } catch (const std::exception& ex) {
    // Re-throw with enough context to identify the failing node.
    auto stackTrace = node_->getMetadata("stack_trace");
    throw std::runtime_error(fmt::format(
        "Exception while executing node: {}\n"
        "with args:\n{}\n"
        "{}\n"
        "Original Python stacktrace:\n{}",
        fmt::streamed(*node_),
        readableArgs(op_.schema(), stack),
        ex.what(),
        stackTrace ? *stackTrace : ""));
  }

  // Write out results
  // TODO: we store intermediates in a single table (symint and tensor alike).
  // This can theoretically lead to name collisions, although based on how
  // these are named I don't think it will ever happen in practice. We need to
  // enforce it though.
  const auto& outputValues = node_->outputs();
  // After the call the stack is treated as the result list, so its length
  // must equal the number of node outputs.
  TORCH_CHECK_EQ(outputValues.size(), stack.size())
      << "Output size mismatch for " << node_->toString();
  for (auto&& [i, actualOutput] : c10::enumerate(stack)) {
    // Results are moved, not copied, into the frame.
    executionFrame.setIValue(outputValues[i]->id(), std::move(actualOutput));
  }
}

namespace {

// Gathers the scalar operands of `node` into a map keyed by operand name.
//
// Node inputs are read from `executionFrame` and must hold an int, double or
// bool. Node attributes are read from the attribute's variant value. Attributes
// are processed after inputs, so an attribute whose name equals an input name
// replaces that input's entry.
//
// Throws std::runtime_error("unsupported type for symbolic input") for any
// input or attribute of another type.
std::unordered_map getSymInputs(
    const ExecutionFrame& executionFrame,
    const Node& node) {
  std::unordered_map inputs;
  for (const auto& input : node.inputs()) {
    const auto& val = executionFrame.getIValue(input.value->id());
    if (val.isInt() || val.isDouble() || val.isBool()) {
      inputs[input.name] = val;
    } else {
      throw std::runtime_error("unsupported type for symbolic input");
    }
  }
  for (const auto& attribute : node.attributes()) {
    // NOTE(review): the alternative type tested by each branch is missing in
    // this copy; presumably they mirror the int/double/bool check applied to
    // inputs above -- confirm against the original source.
    if (std::holds_alternative(attribute.value)) {
      inputs[attribute.name] = std::get(attribute.value);
    } else if (std::holds_alternative(attribute.value)) {
      inputs[attribute.name] = std::get(attribute.value);
    } else if (std::holds_alternative(attribute.value)) {
      inputs[attribute.name] = std::get(attribute.value);
    } else {
      throw std::runtime_error("unsupported type for symbolic input");
    }
  }
  return inputs;
}

// Evaluates a binary arithmetic operator on two scalars of type T and stores
// the result in the first output of `node`.
//
// The operator is chosen by node.target(): "_operator.add", "_operator.sub",
// "_operator.mul" or "_operator.pow". Any other target throws
// std::runtime_error. The operands and the result are logged at VLOG level 2.
//
// NOTE(review): the template parameter list and the enable_if_t arguments are
// missing in this copy.
template void computeScalarBinaryOp(
    ExecutionFrame& executionFrame,
    const Node& node,
    std::enable_if_t a,
    std::enable_if_t b) {
  std::string_view target = node.target();
  // Left uninitialized on purpose: every branch below either assigns it or
  // throws.
  T out;
  if (target == "_operator.add") {
    out = a + b;
  } else if (target == "_operator.sub") {
    out = a - b;
  } else if (target == "_operator.mul") {
    out = a * b;
  } else if (target == "_operator.pow") {
    // NOTE(review): std::pow computes in floating point. When T is an integer
    // type the result is converted back to T, which can lose precision for
    // large magnitudes -- confirm this is acceptable.
    out = std::pow(a, b);
  } else {
    throw std::runtime_error(
        fmt::format("unsupported operator for symbolic values: {}", target));
  }
  executionFrame.setIValue(node.outputs()[0]->id(), out);
  VLOG(2) << fmt::format(
      "Completed executing node: {} with a={}, b={}, out={}",
      fmt::streamed(node),
      a,
      b,
      out);
}

} // namespace

// Executes a binary scalar operator on operands "a" and "b".
//
// If both operands are ints the integer form of computeScalarBinaryOp is used;
// otherwise both are converted to double first. getSymInputs() admits bool
// values, but a bool operand reaches the throwing branch of coerceToDouble.
// A missing "a" or "b" makes the map's at() throw.
void ScalarBinaryOpKernel::computeInternal(
    ExecutionFrame& executionFrame) const {
  auto inputs = getSymInputs(executionFrame, *node_);

  const auto& a = inputs.at("a");
  const auto& b = inputs.at("b");

  // Converts an int or double IValue to double; rejects everything else.
  auto coerceToDouble = [](const c10::IValue& x) -> double {
    if (x.isInt()) {
      return static_cast(x.toInt());
    } else if (x.isDouble()) {
      return x.toDouble();
    } else {
      throw std::runtime_error("unsupported type for symbolic input");
    }
  };

  if (a.isInt() && b.isInt()) {
    computeScalarBinaryOp(
        executionFrame, *node_, a.toInt(), b.toInt());
  } else {
    // Mixed int/double (or double/double): promote both sides to double.
    computeScalarBinaryOp(
        executionFrame, *node_, coerceToDouble(a), coerceToDouble(b));
  }
}

// Executes an integer-valued symbolic operator and stores the result in the
// node's first output.
//
// Targets handled: "torch.sym_float" (unary, reads only "a" and writes a
// double), "_operator.floordiv", "_operator.mod", "torch.sym_max" and
// "torch.sym_min" (binary, read "a" and "b" as ints). Any other target throws
// std::runtime_error.
void SymIntOpKernel::computeInternal(ExecutionFrame& executionFrame) const {
  auto inputs = getSymInputs(executionFrame, *node_);

  int64_t a = inputs.at("a").toInt();
  std::string_view target = node_->target();
  // sym_float is handled before "b" is looked up, so it needs no "b" operand.
  if (target == "torch.sym_float") {
    double out = static_cast(a);
    executionFrame.setIValue(node_->outputs()[0]->id(), out);
    VLOG(2) << fmt::format(
        "Completed executing node: {} with a={}, out={}",
        fmt::streamed(*node_),
        a,
        out);
    return;
  }
  int64_t b = inputs.at("b").toInt();
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int64_t out;
  // NOTE(review): C++ `/` truncates toward zero and `%` takes the sign of the
  // dividend. If these targets are meant to follow Python's operator.floordiv
  // and operator.mod (round toward negative infinity, sign of the divisor),
  // the results differ whenever a and b have opposite signs, e.g. -7 / 2 is
  // -3 here but -7 // 2 is -4 in Python. b == 0 is undefined behaviour for
  // both operators. Confirm operands are always non-negative with b != 0, or
  // switch to floor semantics with an explicit zero check.
  if (target == "_operator.floordiv") {
    out = a / b;
  } else if (target == "_operator.mod") {
    out = a % b;
  } else if (target == "torch.sym_max") {
    out = std::max(a, b);
  } else if (target == "torch.sym_min") {
    out = std::min(a, b);
  } else {
    throw std::runtime_error(
        fmt::format("unsupported operator for SymInt: {}", node_->target()));
  }
  executionFrame.setIValue(node_->outputs()[0]->id(), out);
  VLOG(2) << fmt::format(
      "Completed executing node: {} with a={}, b={}, out={}",
      fmt::streamed(*node_),
      a,
      b,
      out);
}

// Executes a boolean-valued symbolic operator and stores the result in the
// node's first output.
//
// "torch.sym_not" and "_operator.and_" read their operands as bools; the
// comparison targets (ge, le, eq, gt, lt) read "a" and "b" as ints. Any other
// target throws std::runtime_error.
void SymBoolOpKernel::computeInternal(ExecutionFrame& executionFrame) const {
  auto inputs = getSymInputs(executionFrame, *node_);

  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  bool out;

  const std::string_view target = node_->target();
  if (target == "torch.sym_not") {
    bool a = inputs.at("a").toBool();
    out = !a;
  } else if (target == "_operator.ge") {
    int64_t a = inputs.at("a").toInt();
    int64_t b = inputs.at("b").toInt();
    out = a >= b;
  } else if (target == "_operator.le") {
    int64_t a = inputs.at("a").toInt();
    int64_t b = inputs.at("b").toInt();
    out = a <= b;
  } else if (target == "_operator.eq") {
    int64_t a = inputs.at("a").toInt();
    int64_t b = inputs.at("b").toInt();
    out = a == b;
  } else if (target == "_operator.gt") {
    int64_t a = inputs.at("a").toInt();
    int64_t b = inputs.at("b").toInt();
    out = a > b;
  } else if (target == "_operator.lt") {
    int64_t a = inputs.at("a").toInt();
    int64_t b = inputs.at("b").toInt();
    out = a < b;
  } else if (target == "_operator.and_") {
    bool a = inputs.at("a").toBool();
    bool b = inputs.at("b").toBool();
    out = a && b;
  } else {
    throw std::runtime_error(
        fmt::format("unsupported operator for SymBool: {}", node_->target()));
  }

  executionFrame.setIValue(node_->outputs()[0]->id(), out);
}

// Executes a floating-point-related symbolic operator and stores the result
// in the node's first output.
//
// Targets handled:
//   "math.trunc"        - reads "x" as double, writes an int64_t.
//   "torch._sym_sqrt"   - reads "a" as double, writes a double.
//   "_operator.neg"     - negates "a", keeping it int or double as given.
//   "_operator.truediv" - divides "a" by "b" as doubles, promoting ints.
// Any other target throws std::runtime_error.
void SymFloatOpKernel::computeInternal(ExecutionFrame& executionFrame) const {
  auto inputs = getSymInputs(executionFrame, *node_);

  const std::string_view target = node_->target();
  if (target == "math.trunc") {
    // Note the operand key is "x" here, not "a" as in the other branches.
    double x = inputs.at("x").toDouble();
    // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
    int64_t out = trunc(x);
    executionFrame.setIValue(node_->outputs()[0]->id(), out);
  } else if (target == "torch._sym_sqrt") {
    // NOTE(review): unlike the truediv branch below, no int -> double
    // promotion is applied before toDouble(); confirm that an integer operand
    // cannot reach this branch.
    double a = inputs.at("a").toDouble();
    double out = std::sqrt(a);
    executionFrame.setIValue(node_->outputs()[0]->id(), out);
  } else if (target == "_operator.neg") {
    auto a = inputs.at("a");
    // The result is held in an IValue so it can carry either an int or a
    // double, matching the operand's type.
    c10::IValue out;
    if (a.isInt()) {
      out = -a.toInt();
    } else if (a.isDouble()) {
      out = -a.toDouble();
    } else {
      throw std::runtime_error("unsupported type for symbolic input");
    }
    executionFrame.setIValue(node_->outputs()[0]->id(), out);
  } else if (target == "_operator.truediv") {
    // Each operand is promoted to double if it is an int, so the division is
    // always performed in floating point.
    auto ia = inputs.at("a");
    double a = ia.isInt() ? static_cast(ia.toInt()) : ia.toDouble();
    auto ib = inputs.at("b");
    double b = ib.isInt() ? static_cast(ib.toInt()) : ib.toDouble();
    double out = a / b;
    executionFrame.setIValue(node_->outputs()[0]->id(), out);
  } else {
    throw std::runtime_error(
        fmt::format("unsupported operator for SymFloat: {}", node_->target()));
  }
}

} // namespace torch::nativert