diff --git a/bn_modeller/app.py b/bn_modeller/app.py index 0b6773c..7c40177 100644 --- a/bn_modeller/app.py +++ b/bn_modeller/app.py @@ -10,6 +10,15 @@ def create_application(argv: list[str]) -> QApplication: + """ + Create a QApplication instance. + + Args: + argv (list[str]): List of command-line arguments. + + Returns: + QApplication: A new QApplication instance. + """ QCoreApplication.setOrganizationName("Digiratory") QCoreApplication.setOrganizationDomain("digiratory.ru") QCoreApplication.setApplicationName("bn_modeller") @@ -25,6 +34,15 @@ def create_application(argv: list[str]) -> QApplication: def manage_cli_args(app: QCoreApplication) -> QCommandLineParser: + """ + Manage command-line arguments. + + Args: + app (QCoreApplication): QCoreApplication instance. + + Returns: + QCommandLineParser: A new QCommandLineParser instance. + """ cl_parser: QCommandLineParser = QCommandLineParser() cl_parser.setApplicationDescription("Bayesian Network Modeller") cl_parser.addHelpOption() @@ -37,6 +55,16 @@ def manage_cli_args(app: QCoreApplication) -> QCommandLineParser: if __name__ == "__main__": + """ + Entry point of the application. + Args: + app (QCoreApplication): QCoreApplication instance. + cl_parser (QCommandLineParser): A new QCommandLineParser instance. + + Raises: + ValueError: If headless mode is not supported. + + """ app = create_application(sys.argv) cl_parser = manage_cli_args(app) diff --git a/bn_modeller/bayesian_nets/graph_preparation.py b/bn_modeller/bayesian_nets/graph_preparation.py index a152f78..e69b573 100644 --- a/bn_modeller/bayesian_nets/graph_preparation.py +++ b/bn_modeller/bayesian_nets/graph_preparation.py @@ -1,3 +1,5 @@ +from typing import Callable, Tuple + import networkx as nx @@ -40,7 +42,7 @@ def __init__(self, corr_matrix, table_connection, threshold=0): self.G.add_weighted_edges_from(e) - def drop_cycle(self): + def drop_cycle(self, onRemoveEdgeHandler: Callable[[Tuple[int, int]], None] = None): while True: try: cycle_list = nx.find_cycle(self.G) @@ -57,6 +59,8 @@ def drop_cycle(self): min_nodes = i self.G.remove_edge(*min_nodes) + if onRemoveEdgeHandler is not None: + onRemoveEdgeHandler(min_nodes) def getNodeList(self): return self.G.nodes @@ -64,10 +68,9 @@ def getNodeList(self): def getEdgeList(self): return self.G.edges - def getGraph(self): + def getInternalGraph(self): return self.G - def renaming(self): + def getGraph(self) -> nx.Graph: k = {val: key for key, val in self.code_columns.items()} - return nx.relabel_nodes(self.G, k) diff --git a/bn_modeller/bayesian_nets/pyBansheeCalculation.py b/bn_modeller/bayesian_nets/pyBansheeCalculation.py index 2cb8ec9..c32e46b 100644 --- a/bn_modeller/bayesian_nets/pyBansheeCalculation.py +++ b/bn_modeller/bayesian_nets/pyBansheeCalculation.py @@ -47,6 +47,8 @@ def getRankCorr(self): return self.R def getInference(self, len_input_list): + # TODO: Rewrite this function to handle set of features and features to predoct as input. + nodes = list( range(len_input_list) ) # all variables except for value of interest diff --git a/bn_modeller/button_page.py b/bn_modeller/button_page.py index 6cd0bd1..7522eb0 100644 --- a/bn_modeller/button_page.py +++ b/bn_modeller/button_page.py @@ -411,17 +411,17 @@ def on_acycle_graph(self): ) # удалить циклы в графе - G_before = copy.deepcopy(self.graph.renaming()) + G_before = copy.deepcopy(self.graph.getGraph()) self.graph.drop_cycle() self.changeLinkTable() - d = {"Before": G_before, "After": self.graph.renaming()} + d = {"Before": G_before, "After": self.graph.getGraph()} # import pickle # pickle.dump(self.graph, open('graph.txt', 'w')) # print(self.graph.G.nodes()) - nx.write_adjlist(self.graph.renaming(), "graph.txt") + nx.write_adjlist(self.graph.getGraph(), "graph.txt") dialog = SubplotGraph(data=d) self.dialogs.append(dialog) dialog.show() @@ -508,7 +508,7 @@ def changeLinkTable(self): columns=self.input_df.columns, index=self.input_df.columns ) - for i in self.graph.renaming().edges(data=True): + for i in self.graph.getGraph().edges(data=True): newLinkTab.loc[i[0], i[1]] = 1 newLinkTab = newLinkTab.fillna(0) self.updLinkTable = newLinkTab diff --git a/bn_modeller/models/bayesian_inference_model.py b/bn_modeller/models/bayesian_inference_model.py new file mode 100644 index 0000000..3304406 --- /dev/null +++ b/bn_modeller/models/bayesian_inference_model.py @@ -0,0 +1,51 @@ +from PySide6.QtCore import ( + QAbstractItemModel, + QModelIndex, + QObject, + QSortFilterProxyModel, + Qt, + Signal, + Slot, +) + + +class BayesianInferenceModel(QAbstractItemModel): + def __init__(self, parent: QObject = None): + super().__init__(parent) + + def headerData( + self, + section: int, + orientation: Qt.Orientation, + role: Qt.ItemDataRole = Qt.ItemDataRole, + ): + if role == Qt.DisplayRole: + if orientation == Qt.Horizontal: + return f"Column {section}" + elif orientation == Qt.Vertical: + return f"Row {section}" + return super().headerData(section, orientation, role) + + def rowCount(self, parent: QModelIndex = QModelIndex()) -> int: + if not parent.isValid(): + return 10 + return 5 + + def columnCount(self, parent: QModelIndex = QModelIndex()) -> int: + if not parent.isValid(): + return None + return 2 + + def data(self, index: QModelIndex, role: Qt.ItemDataRole = Qt.ItemDataRole): + if not index.isValid(): + return None + row = index.row() + column = index.column() + if role == Qt.DisplayRole: + return f"Row {row}, Column {column}" + return super().data(index, role) + + +class BayesianInferenceIOOnlyModel(QSortFilterProxyModel): + def __init__(self, parent: QObject = None): + super().__init__(parent) diff --git a/bn_modeller/widgets/bn_visualization_view.py b/bn_modeller/widgets/bn_visualization_view.py index 473243c..e953933 100644 --- a/bn_modeller/widgets/bn_visualization_view.py +++ b/bn_modeller/widgets/bn_visualization_view.py @@ -16,7 +16,7 @@ def __init__(self, parent=None, width=12, height=12, dpi=100): super().__init__(self.fig) self.bn_ax = self.fig.add_subplot(1, 1, 1) - def update_plot(self, graph): + def update_plot(self, graph: nx.Graph): self.bn_ax.clear() # self.bn_ax = self.fig.add_subplot(1, 1, 1) @@ -116,7 +116,7 @@ def drawBN(self): graph.drop_cycle() # self.changeLinkTable() - self.bn_canvas.update_plot(graph.renaming()) + self.bn_canvas.update_plot(graph.getGraph()) # import pickle # pickle.dump(self.graph, open('graph.txt', 'w')) # print(self.graph.G.nodes()) diff --git a/bn_modeller/widgets/page/bayesian_inference_page.py b/bn_modeller/widgets/page/bayesian_inference_page.py new file mode 100644 index 0000000..31f0e7f --- /dev/null +++ b/bn_modeller/widgets/page/bayesian_inference_page.py @@ -0,0 +1,33 @@ +from PySide6.QtCore import Qt +from PySide6.QtWidgets import ( + QGridLayout, + QHBoxLayout, + QLabel, + QSplitter, + QTableView, + QTabWidget, + QVBoxLayout, + QWidget, +) + + +class BayesianInferencePageWidget(QWidget): + def __init__(self, parent: QWidget | None = None, f=Qt.WindowType()): + super().__init__(parent, f) + self._init_ui() + + def _init_ui(self): + + self.mainLayout = QVBoxLayout(self) + + # Top Layout with tables + + topLayout = QHBoxLayout() + self.tableSource = QTableView() + self.tableTarget = QTableView() + + topLayout.addWidget(self.tableSource) + topLayout.addWidget(self.tableTarget) + + # Finalization + self.setLayout(self.mainLayout) diff --git a/bn_modeller/widgets/page/bayesian_network_page.py b/bn_modeller/widgets/page/bayesian_network_page.py index 9f27522..b82aed6 100644 --- a/bn_modeller/widgets/page/bayesian_network_page.py +++ b/bn_modeller/widgets/page/bayesian_network_page.py @@ -15,6 +15,7 @@ ) from bn_modeller.widgets import DependencySetupTableView, SelectableListView from bn_modeller.widgets.bn_visualization_view import BayesianNetView +from bn_modeller.widgets.page.bayesian_inference_page import BayesianInferencePageWidget from bn_modeller.widgets.vertical_label import QVertivalLabel @@ -30,13 +31,13 @@ def _init_ui(self): # Dependency tab - # Feature selection + ## Feature selection self.dependencyTabWidget = QSplitter() self.featureSelectorView = SelectableListView() self.dependencyTabWidget.addWidget(self.featureSelectorView) self.dependencyTabWidget.setStretchFactor(0, 1) - # Dependency Table + ## Dependency Table depTableWidget = QWidget() depTableLayout = QGridLayout() @@ -60,6 +61,10 @@ def _init_ui(self): self.tabWidget.addTab(self.visualizationTabWidget, self.tr("Visulization")) + # Inference Tab + self.inferenceTabWidget = BayesianInferencePageWidget() + self.tabWidget.addTab(self.inferenceTabWidget, self.tr("Inference")) + # Finalization self.mainLayout.addWidget(self.tabWidget) self.setLayout(self.mainLayout) diff --git a/docs/GUI_sketches.drawio b/docs/GUI_sketches.drawio new file mode 100644 index 0000000..a317f58 --- /dev/null +++ b/docs/GUI_sketches.drawio @@ -0,0 +1,430 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +