Skip to content

Commit e586c3d

Browse files
committed
Refactor and add license manager to KNITRO.
1 parent c1ea002 commit e586c3d

File tree

6 files changed

+387
-32
lines changed

6 files changed

+387
-32
lines changed

include/pyoptinterface/knitro_model.hpp

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616

1717
// Define Knitro C APIs to be dynamically loaded
1818
#define APILIST \
19+
B(KN_checkout_license); \
20+
B(KN_release_license); \
21+
B(KN_new_lm); \
1922
B(KN_new); \
2023
B(KN_free); \
2124
B(KN_update); \
@@ -103,6 +106,17 @@ struct KNITROFreeProblemT
103106
}
104107
};
105108

109+
struct KNITROFreeLicenseT
110+
{
111+
void operator()(LM_context *lmc) const
112+
{
113+
if (lmc != nullptr)
114+
{
115+
knitro::KN_release_license(&lmc);
116+
}
117+
}
118+
};
119+
106120
enum ObjectiveFlags
107121
{
108122
OBJ_CONSTANT = 1 << 0, // 0x01
@@ -334,6 +348,35 @@ inline ObjectiveSense knitro_obj_sense(int goal)
334348
}
335349
}
336350

351+
inline void knitro_throw(int error)
352+
{
353+
if (error != 0)
354+
{
355+
throw std::runtime_error(fmt::format("KNITRO error code: {}", error));
356+
}
357+
}
358+
359+
class KNITROEnv
360+
{
361+
public:
362+
KNITROEnv(bool empty = false);
363+
364+
KNITROEnv(const KNITROEnv &) = delete;
365+
KNITROEnv &operator=(const KNITROEnv &) = delete;
366+
367+
KNITROEnv(KNITROEnv &&) = default;
368+
KNITROEnv &operator=(KNITROEnv &&) = default;
369+
370+
void start();
371+
bool empty() const;
372+
std::shared_ptr<LM_context> get_lm() const;
373+
void close();
374+
375+
private:
376+
void _check_error(int code) const;
377+
std::shared_ptr<LM_context> m_lm = nullptr;
378+
};
379+
337380
class KNITROModel : public OnesideLinearConstraintMixin<KNITROModel>,
338381
public TwosideLinearConstraintMixin<KNITROModel>,
339382
public OnesideQuadraticConstraintMixin<KNITROModel>,
@@ -346,7 +389,16 @@ class KNITROModel : public OnesideLinearConstraintMixin<KNITROModel>,
346389
public:
347390
// Constructor/Init/Close
348391
KNITROModel();
392+
KNITROModel(const KNITROEnv &env);
393+
394+
KNITROModel(const KNITROModel &) = delete;
395+
KNITROModel &operator=(const KNITROModel &) = delete;
396+
397+
KNITROModel(KNITROModel &&) = default;
398+
KNITROModel &operator=(KNITROModel &&) = default;
399+
349400
void init();
401+
void init(const KNITROEnv &env);
350402
void close();
351403

352404
// Model information
@@ -492,6 +544,7 @@ class KNITROModel : public OnesideLinearConstraintMixin<KNITROModel>,
492544
KNINT _constraint_index(const ConstraintIndex &constraint) const;
493545

494546
// Member variables
547+
std::shared_ptr<LM_context> m_lm = nullptr;
495548
std::unique_ptr<KN_context, KNITROFreeProblemT> m_kc = nullptr;
496549

497550
size_t n_vars = 0;
@@ -513,6 +566,8 @@ class KNITROModel : public OnesideLinearConstraintMixin<KNITROModel>,
513566
int m_solve_status = 0;
514567

515568
private:
569+
void _init();
570+
void _reset_state();
516571
std::tuple<double, double> _sense_to_interval(ConstraintSense sense, double rhs);
517572
void _update_con_sense_flags(const ConstraintIndex &constraint, ConstraintSense sense);
518573

@@ -535,6 +590,19 @@ class KNITROModel : public OnesideLinearConstraintMixin<KNITROModel>,
535590
void _solve();
536591
void _post_solve();
537592

593+
template <typename F>
594+
void _init_impl(const F &ctor)
595+
{
596+
if (!knitro::is_library_loaded())
597+
{
598+
throw std::runtime_error("KNITRO library not loaded");
599+
}
600+
KN_context *kc = nullptr;
601+
int error = ctor(&kc);
602+
_check_error(error);
603+
m_kc = std::unique_ptr<KN_context, KNITROFreeProblemT>(kc);
604+
}
605+
538606
template <typename F>
539607
ConstraintIndex _add_constraint_impl(ConstraintType type,
540608
const std::tuple<double, double> &interval,

lib/knitro_model.cpp

Lines changed: 99 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -50,36 +50,90 @@ bool load_library(const std::string &path)
5050
}
5151
} // namespace knitro
5252

53-
KNITROModel::KNITROModel()
53+
void ensure_library_loaded()
5454
{
55-
init();
55+
if (!knitro::is_library_loaded())
56+
{
57+
throw std::runtime_error("KNITRO library not loaded");
58+
}
5659
}
5760

58-
void KNITROModel::init()
61+
KNITROEnv::KNITROEnv(bool empty)
5962
{
60-
if (!knitro::is_library_loaded())
63+
if (!empty)
6164
{
62-
throw std::runtime_error("KNITRO library is not loaded");
65+
start();
6366
}
67+
}
6468

65-
KN_context *kc_ptr = nullptr;
66-
int error = knitro::KN_new(&kc_ptr);
69+
void KNITROEnv::start()
70+
{
71+
if (!empty())
72+
{
73+
return;
74+
}
75+
ensure_library_loaded();
76+
LM_context *lm = nullptr;
77+
int error = knitro::KN_checkout_license(&lm);
6778
_check_error(error);
79+
m_lm = std::shared_ptr<LM_context>(lm, KNITROFreeLicenseT());
80+
}
6881

69-
m_kc = std::unique_ptr<KN_context, KNITROFreeProblemT>(kc_ptr);
82+
bool KNITROEnv::empty() const
83+
{
84+
return m_lm == nullptr;
7085
}
7186

72-
void KNITROModel::close()
87+
std::shared_ptr<LM_context> KNITROEnv::get_lm() const
7388
{
74-
m_kc.reset();
89+
return m_lm;
90+
}
91+
92+
void KNITROEnv::close()
93+
{
94+
m_lm.reset();
95+
}
96+
97+
void KNITROEnv::_check_error(int code) const
98+
{
99+
knitro_throw(code);
100+
}
101+
102+
KNITROModel::KNITROModel()
103+
{
104+
init();
75105
}
76106

77-
void KNITROModel::_check_error(int error) const
107+
KNITROModel::KNITROModel(const KNITROEnv &env)
78108
{
79-
if (error != 0)
109+
init(env);
110+
}
111+
112+
void KNITROModel::init()
113+
{
114+
m_lm.reset();
115+
_init();
116+
}
117+
118+
void KNITROModel::init(const KNITROEnv &env)
119+
{
120+
if (env.empty())
80121
{
81-
throw std::runtime_error(fmt::format("KNITRO error code: {}", error));
122+
throw std::runtime_error("Empty environment provided. Call start()...");
82123
}
124+
m_lm = env.get_lm();
125+
_init();
126+
}
127+
128+
void KNITROModel::close()
129+
{
130+
_reset_state();
131+
m_lm.reset();
132+
}
133+
134+
void KNITROModel::_check_error(int code) const
135+
{
136+
knitro_throw(code);
83137
}
84138

85139
// Model information
@@ -918,6 +972,38 @@ void KNITROModel::_check_dirty() const
918972
}
919973
}
920974

975+
void KNITROModel::_reset_state()
976+
{
977+
m_kc.reset();
978+
n_vars = 0;
979+
n_cons = 0;
980+
n_lincons = 0;
981+
n_quadcons = 0;
982+
n_coniccons = 0;
983+
n_nlcons = 0;
984+
m_soc_aux_cons.clear();
985+
m_con_sense_flags.clear();
986+
m_obj_flag = 0;
987+
m_pending_outputs.clear();
988+
m_evaluators.clear();
989+
m_need_to_add_callbacks = false;
990+
m_is_dirty = true;
991+
m_solve_status = 0;
992+
}
993+
994+
void KNITROModel::_init()
995+
{
996+
ensure_library_loaded();
997+
_reset_state();
998+
999+
// Create new KNITRO problem
1000+
KN_context *kc = nullptr;
1001+
int error = m_lm ? knitro::KN_new_lm(m_lm.get(), &kc) : knitro::KN_new(&kc);
1002+
knitro_throw(error);
1003+
m_kc = std::unique_ptr<KN_context, KNITROFreeProblemT>(kc);
1004+
}
1005+
1006+
9211007
KNINT KNITROModel::_variable_index(const VariableIndex &variable) const
9221008
{
9231009
return _get_index(variable);

lib/knitro_model_ext.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,19 @@ NB_MODULE(knitro_model_ext, m)
1818

1919
bind_knitro_constants(m);
2020

21+
nb::class_<KNITROEnv>(m, "RawEnv")
22+
.def(nb::init<bool>(), nb::arg("empty") = false)
23+
.def("start", &KNITROEnv::start)
24+
.def("empty", &KNITROEnv::empty)
25+
.def("close", &KNITROEnv::close);
26+
2127
#define BIND_F(f) .def(#f, &KNITROModel::f)
2228
nb::class_<KNITROModel>(m, "RawModel")
2329
.def(nb::init<>())
24-
25-
// clang-format off
26-
BIND_F(init)
30+
.def(nb::init<const KNITROEnv &>())
31+
.def("init", nb::overload_cast<>(&KNITROModel::init))
32+
.def("init", nb::overload_cast<const KNITROEnv &>(&KNITROModel::init))
33+
// clang-format off
2734
BIND_F(close)
2835
BIND_F(get_infinity)
2936
BIND_F(get_number_iterations)

src/pyoptinterface/_src/knitro.py

Lines changed: 40 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
ScalarQuadraticFunction,
2424
VariableIndex,
2525
)
26-
from .knitro_model_ext import KN, RawModel, load_library
26+
from .knitro_model_ext import KN, RawEnv, RawModel, load_library
2727
from .matrix import add_matrix_constraints
2828
from .nlexpr_ext import ExpressionGraph, ExpressionHandle
2929
from .nlfunc import ExpressionGraphContext, convert_to_expressionhandle
@@ -256,15 +256,50 @@ def _result_status_knitro(model: "Model"):
256256
}
257257

258258

259+
class Env(RawEnv):
260+
"""
261+
KNITRO license manager environment.
262+
"""
263+
@property
264+
def is_empty(self):
265+
return self.empty()
266+
267+
259268
class Model(RawModel):
260269
"""
261270
KNITRO model class for PyOptInterface.
262271
"""
263272

264-
def __init__(self):
265-
super().__init__()
273+
def __init__(self, env: Env = None):
274+
if env is not None:
275+
super().__init__(env)
276+
else:
277+
super().__init__()
278+
self._reset_graph_map()
279+
280+
def _reset_graph_map(self):
266281
self.graph_map: dict[ExpressionGraph, int] = {}
267282

283+
def _add_graph_expr(self, expr: ExpressionHandle):
284+
graph = ExpressionGraphContext.current_graph()
285+
expr = convert_to_expressionhandle(graph, expr)
286+
if not isinstance(expr, ExpressionHandle):
287+
raise ValueError("Expression should be convertible to ExpressionHandle")
288+
if graph not in self.graph_map:
289+
self.graph_map[graph] = len(self.graph_map)
290+
return graph, expr
291+
292+
def init(self, env: Env = None):
293+
if env is not None:
294+
super().init(env)
295+
else:
296+
super().init()
297+
self._reset_graph_map()
298+
299+
def close(self):
300+
super().close()
301+
self._reset_graph_map()
302+
268303
@staticmethod
269304
def supports_variable_attribute(
270305
attribute: VariableAttribute, setable: bool = False
@@ -395,21 +430,11 @@ def add_nl_constraint(
395430
): ...
396431

397432
def add_nl_constraint(self, expr, *args, **kwargs):
398-
graph = ExpressionGraphContext.current_graph()
399-
expr = convert_to_expressionhandle(graph, expr)
400-
if not isinstance(expr, ExpressionHandle):
401-
raise ValueError("Expression should be convertible to ExpressionHandle")
402-
if graph not in self.graph_map:
403-
self.graph_map[graph] = len(self.graph_map)
433+
graph, expr = self._add_graph_expr(expr)
404434
return self._add_single_nl_constraint(graph, expr, *args, **kwargs)
405435

406436
def add_nl_objective(self, expr):
407-
graph = ExpressionGraphContext.current_graph()
408-
expr = convert_to_expressionhandle(graph, expr)
409-
if not isinstance(expr, ExpressionHandle):
410-
raise ValueError("Expression should be convertible to ExpressionHandle")
411-
if graph not in self.graph_map:
412-
self.graph_map[graph] = len(self.graph_map)
437+
graph, expr = self._add_graph_expr(expr)
413438
self._add_single_nl_objective(graph, expr)
414439

415440
def get_model_attribute(self, attr: ModelAttribute):

src/pyoptinterface/knitro.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
from pyoptinterface._src.knitro import Model, autoload_library
1+
from pyoptinterface._src.knitro import Env, Model, autoload_library
22
from pyoptinterface._src.knitro_model_ext import (
33
KN,
44
load_library,
55
is_library_loaded,
66
)
77

88
__all__ = [
9+
"Env",
910
"Model",
1011
"KN",
1112
"autoload_library",

0 commit comments

Comments
 (0)