Skip to content

Commit c1ea002

Browse files
authored
[KNITRO] Use reduced hessian pattern for hessian evaluations (#76)
* feat: add symmetric Hessian pattern support. * use refs to avoid unnecessary copy
1 parent 6866290 commit c1ea002

File tree

1 file changed

+18
-8
lines changed

1 file changed

+18
-8
lines changed

include/pyoptinterface/knitro_model.hpp

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ struct CallbackEvaluator
142142
CppAD::sparse_rcv<std::vector<size_t>, std::vector<V>> jac_;
143143
CppAD::sparse_jac_work jac_work_;
144144
CppAD::sparse_rc<std::vector<size_t>> hess_pattern_;
145+
CppAD::sparse_rc<std::vector<size_t>> hess_pattern_symm_;
145146
CppAD::sparse_rcv<std::vector<size_t>, std::vector<V>> hess_;
146147
CppAD::sparse_hes_work hess_work_;
147148

@@ -171,10 +172,19 @@ struct CallbackEvaluator
171172
select_rows[fun_rows[k]] = true;
172173
}
173174
fun.rev_hes_sparsity(select_rows, false, true, hess_pattern_);
175+
for (size_t k = 0; k < hess_pattern_.nnz(); k++)
176+
{
177+
size_t row = hess_pattern_.row()[k];
178+
size_t col = hess_pattern_.col()[k];
179+
if (row <= col)
180+
{
181+
hess_pattern_symm_.push_back(row, col);
182+
}
183+
}
174184
x.resize(fun.Domain(), 0.0);
175185
w.resize(fun.Range(), 0.0);
176186
jac_ = CppAD::sparse_rcv<std::vector<size_t>, std::vector<V>>(jac_pattern_);
177-
hess_ = CppAD::sparse_rcv<std::vector<size_t>, std::vector<V>>(hess_pattern_);
187+
hess_ = CppAD::sparse_rcv<std::vector<size_t>, std::vector<V>>(hess_pattern_symm_);
178188
}
179189

180190
void eval_fun(const V *req_x, V *res_y, bool aggregate = false)
@@ -204,7 +214,7 @@ struct CallbackEvaluator
204214
x[i] = req_x[indexVars[i]];
205215
}
206216
fun.sparse_jac_rev(x, jac_, jac_pattern_, jac_coloring_, jac_work_);
207-
auto jac = jac_.val();
217+
auto& jac = jac_.val();
208218
for (size_t i = 0; i < jac_.nnz(); i++)
209219
{
210220
res_jac[i] = jac[i];
@@ -229,7 +239,7 @@ struct CallbackEvaluator
229239
}
230240
}
231241
fun.sparse_hes(x, w, hess_, hess_pattern_, hess_coloring_, hess_work_);
232-
auto hess = hess_.val();
242+
auto& hess = hess_.val();
233243
for (size_t i = 0; i < hess_.nnz(); i++)
234244
{
235245
res_hess[i] = hess[i];
@@ -241,8 +251,8 @@ struct CallbackEvaluator
241251
CallbackPattern pattern;
242252
pattern.indexCons = indexCons;
243253

244-
auto jac_rows = jac_pattern_.row();
245-
auto jac_cols = jac_pattern_.col();
254+
auto& jac_rows = jac_pattern_.row();
255+
auto& jac_cols = jac_pattern_.col();
246256
if (indexCons.empty())
247257
{
248258
for (size_t k = 0; k < jac_pattern_.nnz(); k++)
@@ -259,9 +269,9 @@ struct CallbackEvaluator
259269
}
260270
}
261271

262-
auto hess_rows = hess_pattern_.row();
263-
auto hess_cols = hess_pattern_.col();
264-
for (size_t k = 0; k < hess_pattern_.nnz(); k++)
272+
auto& hess_rows = hess_pattern_symm_.row();
273+
auto& hess_cols = hess_pattern_symm_.col();
274+
for (size_t k = 0; k < hess_pattern_symm_.nnz(); k++)
265275
{
266276
pattern.hessIndexVars1.push_back(indexVars[hess_rows[k]]);
267277
pattern.hessIndexVars2.push_back(indexVars[hess_cols[k]]);

0 commit comments

Comments
 (0)