diff --git a/tmva/sofie/src/RModel.cxx b/tmva/sofie/src/RModel.cxx index 9caef5ab9da50..6195d93528104 100644 --- a/tmva/sofie/src/RModel.cxx +++ b/tmva/sofie/src/RModel.cxx @@ -1001,6 +1001,30 @@ void RModel::GenerateOutput() if (!doInferArgs.empty()) doInferArgs.back() = ' '; + // verifying if the dynamic parameters are within allowed range + std::unordered_set input_params_checked; + std::string dynamic_parameters_check = ""; + for (auto &name : fInputTensorNames) { + if (IsDimInputTensor(name)) { + auto shape = GetDynamicTensorShape(name); + for (auto &d : shape) { + std::string pName = d.param; + if (d.isParam && input_params_checked.count(pName) == 0) { + std::string cap = d.param; + if (!cap.empty()) { + cap[0] = std::toupper(static_cast(cap[0])); + } + dynamic_parameters_check += d.param + " > f" + cap + " || "; + input_params_checked.insert(pName); + fGC += SP + "if (" + d.param + " > f" + cap + ") {\n"; + fGC += SP + SP + "throw std::runtime_error(\"TMVA-SOFIE: dynamic input tensor shape parameter " + + d.param + " exceeds the initialized maximum allowed shape.\");\n"; + fGC += SP + "}\n"; + } + } + } + } + fGC += SP + "doInfer(" + doInferArgs + ");\n"; fGC += SP + "return {"; @@ -1068,7 +1092,18 @@ void RModel::GenerateSessionCode() // generate code for declarations of some specific operators GenerateOperatorDeclarations(); - + // storing the parameters for future checking to avoid mismatches + if (!fDimShapeNames.empty()) { + fGC += "\n\n"; + std::sort(fDimShapeNames.begin(), fDimShapeNames.end()); + for (const auto &p : fDimShapeNames) { + std::string cap = p; + if (!cap.empty()) { + cap[0] = std::toupper(static_cast(cap[0])); + } + fGC += "size_t f" + cap + ";\n"; + } + } // add subgraph session if (!fSubGraphs.empty()) fGC += "// subgraph sessions\n"; @@ -1115,6 +1150,19 @@ void RModel::GenerateSessionCode() } fGC += ") {\n"; + // initializing dynamic parameters + if (!fDimShapeNames.empty()) { + fGC += "\n\n"; + std::sort(fDimShapeNames.begin(), fDimShapeNames.end()); + for (const auto &p : fDimShapeNames) { + std::string cap = p; + if (!cap.empty()) { + cap[0] = std::toupper(static_cast(cap[0])); + } + fGC += " f" + cap + " = " + p + ";\n"; + } + } + if (fUseWeightFile) { fGC += "\n//--- reading weights from file\n"; ReadInitializedTensorsFromFile(fReadPos);