Skip to content

Commit e8284e7

Browse files
authored
Merge pull request #73 from eminyouskn/cppad-obj-agg
[CppAD] Enable Multi-Output Objective Trace Functions
2 parents 0d4bd17 + ec1b77a commit e8284e7

File tree

4 files changed

+19
-10
lines changed

4 files changed

+19
-10
lines changed

include/pyoptinterface/cppad_interface.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ ADFunDouble sparse_hessian(const ADFunDouble &f, const sparsity_pattern_t &patte
3333

3434
// Transform ExpressionGraph to CppAD function
3535
ADFunDouble cppad_trace_graph_constraints(const ExpressionGraph &graph);
36-
ADFunDouble cppad_trace_graph_objective(const ExpressionGraph &graph);
36+
ADFunDouble cppad_trace_graph_objective(const ExpressionGraph &graph, bool aggregate = true);
3737

3838
struct CppADAutodiffGraph
3939
{

lib/cppad_interface.cpp

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,7 @@ ADFunDouble cppad_trace_graph_constraints(const ExpressionGraph &graph)
469469
return f;
470470
}
471471

472-
ADFunDouble cppad_trace_graph_objective(const ExpressionGraph &graph)
472+
ADFunDouble cppad_trace_graph_objective(const ExpressionGraph &graph, bool aggregate)
473473
{
474474
ankerl::unordered_dense::map<ExpressionHandle, CppAD::AD<double>> seen_expressions;
475475

@@ -503,15 +503,22 @@ ADFunDouble cppad_trace_graph_objective(const ExpressionGraph &graph)
503503
y[i] = cppad_trace_expression(graph, output, x, p, seen_expressions);
504504
}
505505

506-
CppAD::AD<double> y_sum = 0.0;
507-
for (size_t i = 0; i < N_outputs; i++)
506+
ADFunDouble f;
507+
508+
if (aggregate)
509+
{
510+
CppAD::AD<double> y_sum = 0.0;
511+
for (size_t i = 0; i < N_outputs; i++)
512+
{
513+
y_sum += y[i];
514+
}
515+
f.Dependent(x, {y_sum});
516+
}
517+
else
508518
{
509-
y_sum += y[i];
519+
f.Dependent(x, y);
510520
}
511521

512-
ADFunDouble f;
513-
f.Dependent(x, {y_sum});
514-
515522
return f;
516523
}
517524

lib/cppad_interface_ext.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,6 @@ NB_MODULE(cppad_interface_ext, m)
183183
.def_ro("hessian", &CppADAutodiffGraph::hessian_graph);
184184

185185
m.def("cppad_trace_graph_constraints", cppad_trace_graph_constraints);
186-
m.def("cppad_trace_graph_objective", cppad_trace_graph_objective);
186+
m.def("cppad_trace_graph_objective", cppad_trace_graph_objective, nb::arg("graph"), nb::arg("aggregate") = true);
187187
m.def("cppad_autodiff", &cppad_autodiff);
188188
}

lib/knitro_model.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -806,7 +806,9 @@ void KNITROModel::_add_objective_callback(ExpressionGraph *graph, const Outputs
806806
evaluator->eval_hess(req->x, req->sigma, res->hess, true);
807807
return 0;
808808
};
809-
auto trace = cppad_trace_graph_objective;
809+
auto trace = [](const ExpressionGraph &graph) {
810+
return cppad_trace_graph_objective(graph, false);
811+
};
810812
_add_callback_impl(*graph, outputs.obj_idxs, {}, trace, f, g, h);
811813
}
812814

0 commit comments

Comments
 (0)