diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c98c989..98afe66 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -22,7 +22,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v2 with: - python-version: "3.10" + python-version: "3.11" - name: Install dependencies run: | diff --git a/orangecontrib/explain/explainer.py b/orangecontrib/explain/explainer.py index 568ff51..8a084c2 100644 --- a/orangecontrib/explain/explainer.py +++ b/orangecontrib/explain/explainer.py @@ -60,12 +60,17 @@ def _join_shap_values( of lists with the np.ndarray for each class, when explaining regression, the result is the list of one np.ndarrays. """ - if isinstance(shap_values[0], np.ndarray): - # regression + shape = shap_values[0].shape + if len(shape) == 1 or (len(shape) == 2 and shape[0] == 1): + # regression and xgb with two classes return [np.vstack(shap_values)] else: # classification - return [np.vstack(s) for s in zip(*shap_values)] + if len(shape) == 3: + transformed = [(np.squeeze(v, axis=0)).T for v in shap_values] + else: + transformed = [v.T for v in shap_values] + return [np.vstack(s) for s in zip(*transformed)] def _explain_trees( diff --git a/orangecontrib/explain/inspection.py b/orangecontrib/explain/inspection.py index 3af0263..2a80e43 100644 --- a/orangecontrib/explain/inspection.py +++ b/orangecontrib/explain/inspection.py @@ -4,6 +4,7 @@ import numpy as np import scipy.sparse as sp from sklearn.inspection import partial_dependence +from sklearn.utils import Tags, TargetTags from Orange.base import Model from Orange.classification import Model as ClsModel @@ -202,11 +203,13 @@ def dummy_fit(*_, **__): model.fit = dummy_fit model.fit_ = dummy_fit if model.domain.class_var.is_discrete: - model._estimator_type = "classifier" model.classes_ = np.array(model.domain.class_var.values) + estimator_type = "classifier" else: - model._estimator_type = "regressor" + estimator_type = "regressor" + model.__sklearn_tags__ = lambda: Tags(estimator_type=estimator_type, + target_tags=TargetTags(required=True)) progress_callback(0.1) dep = partial_dependence(model, diff --git a/orangecontrib/explain/widgets/owice.py b/orangecontrib/explain/widgets/owice.py index 5e68bb5..3c8c0e7 100644 --- a/orangecontrib/explain/widgets/owice.py +++ b/orangecontrib/explain/widgets/owice.py @@ -26,6 +26,7 @@ create_annotated_table from Orange.widgets.utils.concurrent import TaskState, ConcurrentWidgetMixin from Orange.widgets.utils.itemmodels import VariableListModel, DomainModel +from Orange.widgets.utils.multi_target import check_multiple_targets_input from Orange.widgets.utils.sql import check_sql_input from Orange.widgets.utils.widgetpreview import WidgetPreview from Orange.widgets.visualize.owdistributions import LegendItem @@ -37,10 +38,6 @@ from orangecontrib.explain.inspection import individual_condition_expectation from orangewidget.utils.visual_settings_dlg import VisualSettingsDialog -try: - from Orange.widgets.utils.multi_target import check_multiple_targets_input -except ImportError: - check_multiple_targets_input = lambda f: f class RunnerResults(SimpleNamespace): @@ -734,7 +731,7 @@ def _apply_feature_sorting(self): if self.order_by_importance: def compute_score(feature): values = self.__results_avgs[feature][self.target_index] - return -np.sum(np.abs(values - np.mean(values))) + return float(-np.sum(np.abs(values - np.mean(values)))) try: if self.__results_avgs is None: diff --git a/orangecontrib/explain/widgets/tests/test_owpermutationimportance.py b/orangecontrib/explain/widgets/tests/test_owpermutationimportance.py index 6e9945b..86e2699 100644 --- a/orangecontrib/explain/widgets/tests/test_owpermutationimportance.py +++ b/orangecontrib/explain/widgets/tests/test_owpermutationimportance.py @@ -281,19 +281,19 @@ def test_x_label(self): self.send_signal(self.widget.Inputs.model, self.rf_cls) self.wait_until_finished() label: QGraphicsTextItem = self.widget.plot.bottom_axis.label - self.assertEqual(label.toPlainText(), "Decrease in AUC ") + self.assertIn("Decrease in AUC", label.toPlainText()) self.send_signal(self.widget.Inputs.data, self.housing) self.send_signal(self.widget.Inputs.model, self.rf_reg) self.wait_until_finished() label: QGraphicsTextItem = self.widget.plot.bottom_axis.label - self.assertEqual(label.toPlainText(), "Decrease in R2 ") + self.assertIn("Decrease in R2", label.toPlainText()) score_cb: QComboBox = self.widget._score_combo simulate.combobox_activate_item(score_cb, "MSE") self.wait_until_finished() label: QGraphicsTextItem = self.widget.plot.bottom_axis.label - self.assertEqual(label.toPlainText(), "Increase in MSE ") + self.assertIn("Increase in MSE", label.toPlainText()) @unittest.mock.patch("orangecontrib.explain.widgets." "owpermutationimportance.OWPermutationImportance.run") diff --git a/setup.py b/setup.py index 7645712..226a5e0 100644 --- a/setup.py +++ b/setup.py @@ -38,18 +38,16 @@ ] INSTALL_REQUIRES = [ - "AnyQt", - # shap's requirement, force users for numba to get updated because compatibility - # issues with numpy - completely remove this pin after october 2024 - "numba >=0.58", - "numpy", - "Orange3 >=3.36.2", - "orange-canvas-core >=0.1.30", - "orange-widget-base >=4.22.0", - "pyqtgraph", - "scipy", - "shap==0.42.1", - "scikit-learn>=1.3.0", + "AnyQt>=0.2.0", + "Orange3>=3.39.0", + "orange-canvas-core>=0.2.5", + "orange-widget-base>=4.25.0", + "pandas>=2.2.2", + "scikit-learn>=1.7.0", + "scipy>=1.13.0", + "pyqtgraph>=0.13.1", + "numpy>=2.0.0", + "shap>=0.50.0", ] EXTRAS_REQUIRE = { diff --git a/tox.ini b/tox.ini index c8e7f24..d12e9c2 100644 --- a/tox.ini +++ b/tox.ini @@ -22,13 +22,16 @@ deps = {env:PYQT_PYPI_NAME:PyQt5}=={env:PYQT_PYPI_VERSION:5.15.*} {env:WEBENGINE_PYPI_NAME:PyQtWebEngine}=={env:WEBENGINE_PYPI_VERSION:5.15.*} xgboost - oldest: orange3==3.36.2 - oldest: orange-canvas-core==0.1.30 - oldest: orange-widget-base==4.22.0 - oldest: pandas==1.4.0 - oldest: scikit-learn==1.3.0 - oldest: scipy==1.9.0 + oldest: orange3==3.39.0 + oldest: orange-canvas-core==0.2.5 + oldest: orange-widget-base==4.25.0 + oldest: pandas~=2.2.2 + oldest: scikit-learn~=1.7.0 + oldest: scipy~=1.13.0 oldest: xgboost==2.0.0 + oldest: pyqtgraph==0.13.1 + oldest: numpy~=2.0.0 + oldest: shap==0.50.0 latest: https://github.com/biolab/orange3/archive/refs/heads/master.zip#egg=orange3 latest: https://github.com/biolab/orange-canvas-core/archive/refs/heads/master.zip#egg=orange-canvas-core latest: https://github.com/biolab/orange-widget-base/archive/refs/heads/master.zip#egg=orange-widget-base