Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,7 @@ SOURCE_FILES = \
HexagonOffload.cpp \
HexagonOptimize.cpp \
ImageParam.cpp \
Inductive.cpp \
InferArguments.cpp \
InjectHostDevBufferCopies.cpp \
Inline.cpp \
Expand Down Expand Up @@ -720,6 +721,7 @@ HEADER_FILES = \
HexagonOffload.h \
HexagonOptimize.h \
ImageParam.h \
Inductive.h \
InferArguments.h \
InjectHostDevBufferCopies.h \
Inline.h \
Expand Down
9 changes: 0 additions & 9 deletions python_bindings/test/correctness/basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,15 +489,6 @@ def test_unevaluated_funcref():
f[x] += 1
assert f.realize([1])[0] == 1

with assert_throws(
hl.HalideError,
r"Error: Can't call Func \"f(\$\d+)?\" because it has not yet been defined\.",
):
# This is invalid because we only allow unevaluated func refs on the LHS of a
# binary operator.
f = hl.Func("f")
f[x] = 1 + f[x]

with assert_throws(
hl.HalideError,
r"Cannot use an unevaluated reference to 'f(\$\d+)?' to define an update at a different location\.",
Expand Down
2 changes: 1 addition & 1 deletion src/Bounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3299,7 +3299,7 @@ FuncValueBounds compute_function_value_bounds(const vector<string> &order,

Interval result;

if (f.is_pure()) {
if (f.is_pure() && !f.is_inductive()) {

// Make a scope that says the args could be anything.
Scope<Interval> arg_scope;
Expand Down
16 changes: 16 additions & 0 deletions src/BoundsInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "IREquality.h"
#include "IRMutator.h"
#include "IROperator.h"
#include "Inductive.h"
#include "Inline.h"
#include "Qualify.h"
#include "Scope.h"
Expand Down Expand Up @@ -1021,6 +1022,21 @@ class BoundsInference : public IRMutator {
}
}

// For any inductively defined functions, make sure their
// bounds include the base case.
for (Stage &s : stages) {
if (!s.func.is_pure() || !s.func.is_inductive()) {
continue;
}
debug(4) << "Expanding bounds for inductively defined function " << s.func.name() << "\n";
for (const auto &b1 : s.bounds) {
const Box &b = b1.second;
for (const auto &cval : s.exprs) {
s.bounds[b1.first] = expand_to_include_base_case(s.func.args(), cval.value, s.func.name(), b);
}
}
}

// The region required of the each output is expanded to include the size of the output buffer.
for (const Function &output : outputs) {
Box output_box;
Expand Down
2 changes: 2 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ target_sources(
HexagonOffload.h
HexagonOptimize.h
ImageParam.h
Inductive.h
InferArguments.h
InjectHostDevBufferCopies.h
Inline.h
Expand Down Expand Up @@ -305,6 +306,7 @@ target_sources(
HexagonOffload.cpp
HexagonOptimize.cpp
ImageParam.cpp
Inductive.cpp
InferArguments.cpp
InjectHostDevBufferCopies.cpp
Inline.cpp
Expand Down
2 changes: 2 additions & 0 deletions src/Deserialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,8 @@ DimType Deserializer::deserialize_dim_type(Serialize::DimType dim_type) {
return DimType::PureRVar;
case Serialize::DimType::ImpureRVar:
return DimType::ImpureRVar;
case Serialize::DimType::InductiveVar:
return DimType::InductiveVar;
default:
user_error << "unknown dim type " << (int)dim_type << "\n";
return DimType::PureVar;
Expand Down
20 changes: 19 additions & 1 deletion src/Func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ Func::Func(const string &name)
: func(unique_name(name)) {
}

Func::Func(const Type &required_type, const string &name)
: func({required_type}, AnyDims, unique_name(name)) {
}

Func::Func(const Type &required_type, int required_dims, const string &name)
: func({required_type}, required_dims, unique_name(name)) {
}
Expand Down Expand Up @@ -491,7 +495,7 @@ void Stage::set_dim_type(const VarOrRVar &var, ForType t) {
// If it's an rvar and the for type is parallel, we need to
// validate that this doesn't introduce a race condition,
// unless it is flagged explicitly or is a associative atomic operation.
if (!dim.is_pure() && var.is_rvar && is_parallel(t)) {
if (!dim.is_pure() && (var.is_rvar || dim.is_inductive()) && is_parallel(t)) {
if (!definition.schedule().allow_race_conditions() &&
definition.schedule().atomic()) {
if (!definition.schedule().override_atomic_associativity_test()) {
Expand Down Expand Up @@ -1342,6 +1346,9 @@ Stage &Stage::fuse(const VarOrRVar &inner, const VarOrRVar &outer, const VarOrRV
} else if (dims[i].dim_type == DimType::PureRVar ||
outer_type == DimType::PureRVar) {
dims[i].dim_type = DimType::PureRVar;
} else if (dims[i].dim_type == DimType::InductiveVar ||
outer_type == DimType::InductiveVar) {
dims[i].dim_type = DimType::InductiveVar;
} else {
dims[i].dim_type = DimType::PureVar;
}
Expand Down Expand Up @@ -3304,8 +3311,19 @@ Stage FuncRef::operator/=(const FuncRef &e) {
}

FuncRef::operator Expr() const {
/*
user_assert(func.has_pure_definition() || func.has_extern_definition())
<< "Can't call Func \"" << func.name() << "\" because it has not yet been defined.\n";
*/

if (!(func.has_pure_definition() || func.has_extern_definition())) {
Type t = Type(Type::Unknown, 0, 1);
if (!func.required_types().empty()) {
t = func.required_types()[0];
}
return Call::make(t, func.name(), args, Call::Halide,
func.get_contents(), 0, Buffer<>(), Parameter());
}

user_assert(func.outputs() == 1)
<< "Can't convert a reference Func \"" << func.name()
Expand Down
4 changes: 4 additions & 0 deletions src/Func.h
Original file line number Diff line number Diff line change
Expand Up @@ -734,6 +734,10 @@ class Func {
/** Declare a new undefined function with the given name */
explicit Func(const std::string &name);

/** Declare a new undefined function with the given name.
* The function will be constrained to represent Exprs of required_type. */
explicit Func(const Type &required_type, const std::string &name);

/** Declare a new undefined function with the given name.
* The function will be constrained to represent Exprs of required_type.
* If required_dims is not AnyDims, the function will be constrained to exactly
Expand Down
155 changes: 140 additions & 15 deletions src/Function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,9 +207,10 @@ struct CheckVars : public IRGraphVisitor {
Scope<> defined_internally;
const std::string name;
bool unbound_reduction_vars_ok = false;
bool pure;

CheckVars(const std::string &n)
: name(n) {
CheckVars(const std::string &n, bool pure)
: name(n), pure(pure) {
}

using IRVisitor::visit;
Expand All @@ -222,15 +223,34 @@ struct CheckVars : public IRGraphVisitor {

void visit(const Call *op) override {
IRGraphVisitor::visit(op);
if (op->name == name && op->call_type == Call::Halide) {
for (size_t i = 0; i < op->args.size(); i++) {
const Variable *var = op->args[i].as<Variable>();
if (!pure_args[i].empty()) {
user_assert(var && var->name == pure_args[i])
<< "In definition of Func \"" << name << "\":\n"
<< "All of a function's recursive references to itself"
<< " must contain the same pure variables in the same"
<< " places as on the left-hand-side.\n";
bool proper_func = (name == op->name || op->call_type != Call::Halide);
if (op->call_type == Call::Halide &&
op->func.defined() &&
op->name != name) {
Function func = Function(op->func);
proper_func = (name == op->name || func.has_pure_definition() || func.has_extern_definition());
}
if (pure) {
user_assert(proper_func)
<< "In pure definition of Func \"" << name << "\":\n"
<< "Can't call Func \"" << op->name
<< "\" because it has not yet been defined,"
<< " and it is not a recursive call.\n";
} else {
user_assert(proper_func)
<< "In update definition of Func \"" << name << "\":\n"
<< "Can't call Func \"" << op->name
<< "\" because it has not yet been defined.\n";
if (op->name == name && op->call_type == Call::Halide) {
for (size_t i = 0; i < op->args.size(); i++) {
const Variable *var = op->args[i].as<Variable>();
if (!pure_args[i].empty()) {
user_assert(var && var->name == pure_args[i])
<< "In update definition of Func \"" << name << "\":\n"
<< "All of a function's recursive references to itself"
<< " in update definitions must contain the same pure"
<< " variables in the same places as on the left-hand-side.\n";
}
}
}
}
Expand Down Expand Up @@ -562,14 +582,15 @@ void Function::define(const vector<string> &args, vector<Expr> values) {

// Make sure all the vars in the value are either args or are
// attached to some parameter
CheckVars check(name());
CheckVars check(name(), true);
check.pure_args = args;
for (const auto &value : values) {
value.accept(&check);
}

// Freeze all called functions
FreezeFunctions freezer(name());
// TODO: Check for calls to undefined Funcs
for (const auto &value : values) {
value.accept(&freezer);
}
Expand Down Expand Up @@ -629,11 +650,31 @@ void Function::define(const vector<string> &args, vector<Expr> values) {
init_def_args[i] = Var(args[i]);
}

// If the function is inductive,
// the value and args might refer back to the
// function itself, introducing circular references and hence
// memory leaks. We need to break these cycles.
WeakenFunctionPtrs weakener(contents.get());
for (auto &arg : init_def_args) {
arg = weakener.mutate(arg);
}
for (auto &value : values) {
value = weakener.mutate(value);
}
if (check.reduction_domain.defined()) {
check.reduction_domain.set_predicate(
weakener.mutate(check.reduction_domain.predicate()));
}

ReductionDomain rdom;
contents->init_def = Definition(init_def_args, values, rdom, true);

for (const auto &arg : args) {
Dim d = {arg, ForType::Serial, DeviceAPI::None, DimType::PureVar};
DimType dtype = DimType::PureVar;
if (is_inductive(arg)) {
dtype = DimType::InductiveVar;
}
Dim d = {arg, ForType::Serial, DeviceAPI::None, dtype};
contents->init_def.schedule().dims().push_back(d);
StorageDim sd = {arg};
contents->func_schedule.storage_dims().push_back(sd);
Expand Down Expand Up @@ -689,6 +730,9 @@ void Function::define_update(const vector<Expr> &_args, vector<Expr> values, con
user_assert(!frozen())
<< "Func " << name() << " cannot be given a new update definition, "
<< "because it has already been realized or used in the definition of another Func.\n";
user_assert(!is_inductive())
<< "In update definition " << update_idx << " of Func \"" << name() << "\":\n"
<< "Inductive functions cannot have update definitions.\n";

for (auto &value : values) {
user_assert(value.defined())
Expand Down Expand Up @@ -759,7 +803,7 @@ void Function::define_update(const vector<Expr> &_args, vector<Expr> values, con
// pure args, in the reduction domain, or a parameter. Also checks
// that recursive references to the function contain all the pure
// vars in the LHS in the correct places.
CheckVars check(name());
CheckVars check(name(), false);
check.pure_args = pure_args;
for (const auto &arg : args) {
arg.accept(&check);
Expand Down Expand Up @@ -1066,8 +1110,89 @@ bool Function::has_pure_definition() const {
return contents->init_def.defined();
}

bool Function::is_inductive() const {
class RecursiveHelper : public IRVisitor {
using IRVisitor::visit;
const string &func;
void visit(const Call *op) override {
if (op->name == func) {
recursive = true;
}
IRVisitor::visit(op);
}

public:
bool recursive = false;
RecursiveHelper(const string &func)
: func(func) {
}
};

if (!has_pure_definition()) {
return false;
}

RecursiveHelper r(name());
for (const Expr &e : definition().values()) {
e.accept(&r);
}

return r.recursive;
}

bool Function::is_inductive(const string &var) const {
class RecursiveHelper : public IRVisitor {
using IRVisitor::visit;
const string &func;
const string &var;
const int &pos;
void visit(const Call *op) override {
if (op->name == func) {
recursive = true;
if (const auto &v = op->args[pos].as<Variable>()) {
if (v->name != var) {
inductive_in_var = true;
}
} else {
inductive_in_var = true;
}
}
IRVisitor::visit(op);
}

public:
bool recursive = false;
bool inductive_in_var = false;
RecursiveHelper(const string &func, const string &var, const int &pos)
: func(func), var(var), pos(pos) {
}
};

if (!has_pure_definition()) {
return false;
}

int pos = -1;
for (size_t i = 0; i < definition().args().size(); i++) {
if (const auto &v = definition().args()[i].as<Variable>()) {
if (v->name == var) {
pos = i;
}
}
}
if (pos == -1) {
return false;
}
RecursiveHelper r(name(), var, pos);
for (const Expr &e : definition().values()) {
e.accept(&r);
}

return r.inductive_in_var;
}

bool Function::can_be_inlined() const {
return is_pure() && definition().specializations().empty();
return is_pure() && definition().specializations().empty() && !is_inductive();
}

bool Function::has_update_definition() const {
Expand Down
6 changes: 6 additions & 0 deletions src/Function.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,12 @@ class Function {
!has_extern_definition());
}

/** Does this function have an inductive pure definition? */
bool is_inductive() const;

/** Is this function inductive in the given variable? */
bool is_inductive(const std::string &var) const;

/** Is it legal to inline this function? */
bool can_be_inlined() const;

Expand Down
Loading