Skip to content

Commit 99389e9

Browse files
committed
enhanace KNITRO implementation
1 parent e586c3d commit 99389e9

File tree

5 files changed

+422
-68
lines changed

5 files changed

+422
-68
lines changed

include/pyoptinterface/knitro_model.hpp

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,14 @@ class KNITROModel : public OnesideLinearConstraintMixin<KNITROModel>,
475475
double get_mip_relative_gap() const;
476476
double get_solve_time() const;
477477

478+
// Model state
479+
bool dirty() const;
480+
bool empty() const;
481+
482+
// Solve status
483+
int get_solve_status() const;
484+
485+
// Parameter management
478486
template <typename T>
479487
void set_raw_parameter(const std::string &name, T value)
480488
{
@@ -539,31 +547,33 @@ class KNITROModel : public OnesideLinearConstraintMixin<KNITROModel>,
539547

540548
// Internal helpers
541549
void _check_error(int error) const;
550+
void _mark_dirty();
551+
void _unmark_dirty();
542552
void _check_dirty() const;
543553
KNINT _variable_index(const VariableIndex &variable) const;
544554
KNINT _constraint_index(const ConstraintIndex &constraint) const;
545555

546-
// Member variables
547-
std::shared_ptr<LM_context> m_lm = nullptr;
548-
std::unique_ptr<KN_context, KNITROFreeProblemT> m_kc = nullptr;
549-
550556
size_t n_vars = 0;
551557
size_t n_cons = 0;
552558
size_t n_lincons = 0;
553559
size_t n_quadcons = 0;
554560
size_t n_coniccons = 0;
555561
size_t n_nlcons = 0;
556562

563+
private:
564+
// Member variables
565+
std::shared_ptr<LM_context> m_lm = nullptr;
566+
std::unique_ptr<KN_context, KNITROFreeProblemT> m_kc = nullptr;
567+
557568
std::unordered_map<KNINT, std::variant<KNINT, std::pair<KNINT, KNINT>>> m_soc_aux_cons;
558569
std::unordered_map<KNINT, uint8_t> m_con_sense_flags;
559570
uint8_t m_obj_flag = 0;
560571

561572
std::unordered_map<ExpressionGraph *, Outputs> m_pending_outputs;
562573
std::vector<std::unique_ptr<CallbackEvaluator<double>>> m_evaluators;
563574
bool m_need_to_add_callbacks = false;
564-
565-
bool m_is_dirty = true;
566575
int m_solve_status = 0;
576+
bool m_is_dirty = true;
567577

568578
private:
569579
void _init();
@@ -590,19 +600,6 @@ class KNITROModel : public OnesideLinearConstraintMixin<KNITROModel>,
590600
void _solve();
591601
void _post_solve();
592602

593-
template <typename F>
594-
void _init_impl(const F &ctor)
595-
{
596-
if (!knitro::is_library_loaded())
597-
{
598-
throw std::runtime_error("KNITRO library not loaded");
599-
}
600-
KN_context *kc = nullptr;
601-
int error = ctor(&kc);
602-
_check_error(error);
603-
m_kc = std::unique_ptr<KN_context, KNITROFreeProblemT>(kc);
604-
}
605-
606603
template <typename F>
607604
ConstraintIndex _add_constraint_impl(ConstraintType type,
608605
const std::tuple<double, double> &interval,

lib/knitro_model.cpp

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ VariableIndex KNITROModel::add_variable(VariableDomain domain, double lb, double
184184
}
185185

186186
n_vars++;
187-
m_is_dirty = true;
187+
_mark_dirty();
188188

189189
return variable;
190190
}
@@ -267,7 +267,7 @@ void KNITROModel::set_variable_domain(const VariableIndex &variable, VariableDom
267267
_set_value<KNINT, double>(knitro::KN_set_var_upbnd, indexVar, ub);
268268
}
269269

270-
m_is_dirty = true;
270+
_mark_dirty();
271271
}
272272

273273
double KNITROModel::get_variable_rc(const VariableIndex &variable) const
@@ -285,7 +285,7 @@ void KNITROModel::delete_variable(const VariableIndex &variable)
285285
_set_value<KNINT, double>(knitro::KN_set_var_lobnd, indexVar, -get_infinity());
286286
_set_value<KNINT, double>(knitro::KN_set_var_upbnd, indexVar, get_infinity());
287287
n_vars--;
288-
m_is_dirty = true;
288+
_mark_dirty();
289289
}
290290

291291
std::string KNITROModel::pprint_variable(const VariableIndex &variable) const
@@ -535,7 +535,7 @@ void KNITROModel::delete_constraint(const ConstraintIndex &constraint)
535535
m_soc_aux_cons.erase(it);
536536
}
537537

538-
m_is_dirty = true;
538+
_mark_dirty();
539539
}
540540

541541
void KNITROModel::set_constraint_name(const ConstraintIndex &constraint, const std::string &name)
@@ -580,7 +580,7 @@ void KNITROModel::set_normalized_rhs(const ConstraintIndex &constraint, double r
580580
_set_value<KNINT, double>(knitro::KN_set_con_upbnd, indexCon, rhs);
581581
}
582582

583-
m_is_dirty = true;
583+
_mark_dirty();
584584
}
585585

586586
double KNITROModel::get_normalized_rhs(const ConstraintIndex &constraint) const
@@ -611,7 +611,7 @@ void KNITROModel::set_normalized_coefficient(const ConstraintIndex &constraint,
611611
_update();
612612
int error = knitro::KN_chg_con_linear_term(m_kc.get(), indexCon, indexVar, coefficient);
613613
_check_error(error);
614-
m_is_dirty = true;
614+
_mark_dirty();
615615
}
616616

617617
void KNITROModel::_set_linear_constraint(const ConstraintIndex &constraint,
@@ -708,7 +708,7 @@ void KNITROModel::set_objective_coefficient(const VariableIndex &variable, doubl
708708
_update();
709709
int error = knitro::KN_chg_obj_linear_term(m_kc.get(), indexVar, coefficient);
710710
_check_error(error);
711-
m_is_dirty = true;
711+
_mark_dirty();
712712
}
713713

714714
void KNITROModel::add_single_nl_objective(ExpressionGraph &graph, const ExpressionHandle &result)
@@ -719,7 +719,7 @@ void KNITROModel::add_single_nl_objective(ExpressionGraph &graph, const Expressi
719719
m_pending_outputs[&graph].obj_idxs.push_back(i);
720720
m_need_to_add_callbacks = true;
721721
m_obj_flag |= OBJ_NONLINEAR;
722-
m_is_dirty = true;
722+
_mark_dirty();
723723
}
724724

725725
void KNITROModel::set_obj_sense(ObjectiveSense sense)
@@ -927,7 +927,7 @@ void KNITROModel::optimize()
927927
_pre_solve();
928928
_solve();
929929
_post_solve();
930-
m_is_dirty = false;
930+
_unmark_dirty();
931931
}
932932

933933
// Solve information
@@ -963,10 +963,38 @@ double KNITROModel::get_solve_time() const
963963
return _get_value<double>(knitro::KN_get_solve_time_real);
964964
}
965965

966+
// Dirty state management
967+
void KNITROModel::_mark_dirty()
968+
{
969+
m_is_dirty = true;
970+
}
971+
972+
void KNITROModel::_unmark_dirty()
973+
{
974+
m_is_dirty = false;
975+
}
976+
977+
bool KNITROModel::dirty() const
978+
{
979+
return m_is_dirty;
980+
}
981+
982+
// Model state
983+
bool KNITROModel::empty() const
984+
{
985+
return m_kc == nullptr;
986+
}
987+
988+
int KNITROModel::get_solve_status() const
989+
{
990+
_check_dirty();
991+
return m_solve_status;
992+
}
993+
966994
// Internal helpers
967995
void KNITROModel::_check_dirty() const
968996
{
969-
if (m_is_dirty)
997+
if (dirty())
970998
{
971999
throw std::runtime_error("Model has been modified since last solve. Call optimize()...");
9721000
}
@@ -987,7 +1015,7 @@ void KNITROModel::_reset_state()
9871015
m_pending_outputs.clear();
9881016
m_evaluators.clear();
9891017
m_need_to_add_callbacks = false;
990-
m_is_dirty = true;
1018+
_mark_dirty();
9911019
m_solve_status = 0;
9921020
}
9931021

lib/knitro_model_ext.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -232,8 +232,15 @@ NB_MODULE(knitro_model_ext, m)
232232
},
233233
nb::arg("param_id"))
234234

235-
.def_rw("m_is_dirty", &KNITROModel::m_is_dirty)
236-
.def_ro("m_solve_status", &KNITROModel::m_solve_status);
235+
// clang-format off
236+
BIND_F(dirty)
237+
BIND_F(empty)
238+
// clang-format on
239+
240+
// clang-format off
241+
BIND_F(get_solve_status)
242+
// clang-format on
243+
;
237244

238245
#undef BIND_F
239246
}

0 commit comments

Comments
 (0)