Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 49 additions & 1 deletion tmva/sofie/src/RModel.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -1001,6 +1001,30 @@ void RModel::GenerateOutput()
if (!doInferArgs.empty())
doInferArgs.back() = ' ';

// verifying if the dynamic parameters are within allowed range
std::unordered_set<std::string> input_params_checked;
std::string dynamic_parameters_check = "";
for (auto &name : fInputTensorNames) {
if (IsDimInputTensor(name)) {
auto shape = GetDynamicTensorShape(name);
for (auto &d : shape) {
std::string pName = d.param;
if (d.isParam && input_params_checked.count(pName) == 0) {
std::string cap = d.param;
if (!cap.empty()) {
cap[0] = std::toupper(static_cast<unsigned char>(cap[0]));
}
dynamic_parameters_check += d.param + " > f" + cap + " || ";
input_params_checked.insert(pName);
fGC += SP + "if (" + d.param + " > f" + cap + ") {\n";
fGC += SP + SP + "throw std::runtime_error(\"TMVA-SOFIE: dynamic input tensor shape parameter " +
d.param + " exceeds the initialized maximum allowed shape.\");\n";
fGC += SP + "}\n";
}
}
}
}

fGC += SP + "doInfer(" + doInferArgs + ");\n";

fGC += SP + "return {";
Expand Down Expand Up @@ -1068,7 +1092,18 @@ void RModel::GenerateSessionCode()
// generate code for declarations of some specific operators
GenerateOperatorDeclarations();


// storing the parameters for future checking to avoid mismatches
if (!fDimShapeNames.empty()) {
fGC += "\n\n";
std::sort(fDimShapeNames.begin(), fDimShapeNames.end());
for (const auto &p : fDimShapeNames) {
std::string cap = p;
if (!cap.empty()) {
cap[0] = std::toupper(static_cast<unsigned char>(cap[0]));
}
fGC += "size_t f" + cap + ";\n";
}
}

// add subgraph session
if (!fSubGraphs.empty()) fGC += "// subgraph sessions\n";
Expand Down Expand Up @@ -1115,6 +1150,19 @@ void RModel::GenerateSessionCode()
}
fGC += ") {\n";

// initializing dynamic parameters
if (!fDimShapeNames.empty()) {
fGC += "\n\n";
std::sort(fDimShapeNames.begin(), fDimShapeNames.end());
for (const auto &p : fDimShapeNames) {
std::string cap = p;
if (!cap.empty()) {
cap[0] = std::toupper(static_cast<unsigned char>(cap[0]));
}
fGC += " f" + cap + " = " + p + ";\n";
}
}

if (fUseWeightFile) {
fGC += "\n//--- reading weights from file\n";
ReadInitializedTensorsFromFile(fReadPos);
Expand Down
Loading