Skip to content

Commit 9e83654

Browse files
fix: update anchor_bfs_partitioner & fix vqa_generator (#166)
1 parent 97a03f2 commit 9e83654

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

graphgen/models/generator/vqa_generator.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def format_generation_results(
106106
{"text": v["question"], "image": v.get("img_path", "")}
107107
],
108108
},
109-
{"from": "gpt", "value": v["answer"]},
109+
{"from": "gpt", "value": [{"text": v["answer"]}]},
110110
]
111111
}
112112
for item in results
@@ -122,7 +122,10 @@ def format_generation_results(
122122
{"text": v["question"], "image": v.get("img_path", "")}
123123
],
124124
},
125-
{"role": "assistant", "content": v["answer"]},
125+
{
126+
"role": "assistant",
127+
"content": [{"type": "text", "text": v["answer"]}],
128+
},
126129
]
127130
}
128131
for item in results

graphgen/models/partitioner/anchor_bfs_partitioner.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,6 @@ def partition(
3737
**kwargs: Any,
3838
) -> Iterable[Community]:
3939
nodes = g.get_all_nodes() # List[tuple[id, meta]]
40-
edges = g.get_all_edges() # List[tuple[u, v, meta]]
41-
42-
adj, _ = self._build_adjacency_list(nodes, edges)
4340

4441
anchors: Set[str] = self._pick_anchor_ids(nodes)
4542
if not anchors:
@@ -55,7 +52,7 @@ def partition(
5552
if seed_node in used_n:
5653
continue
5754
comm_n, comm_e = self._grow_community(
58-
seed_node, adj, max_units_per_community, used_n, used_e
55+
seed_node, g, max_units_per_community, used_n, used_e
5956
)
6057
if comm_n or comm_e:
6158
yield Community(id=seed_node, nodes=comm_n, edges=comm_e)
@@ -77,15 +74,15 @@ def _pick_anchor_ids(
7774
@staticmethod
7875
def _grow_community(
7976
seed: str,
80-
adj: dict[str, List[str]],
77+
g: BaseGraphStorage,
8178
max_units: int,
8279
used_n: set[str],
8380
used_e: set[frozenset[str]],
8481
) -> Tuple[List[str], List[Tuple[str, str]]]:
8582
"""
8683
Grow a community from the seed node using BFS.
8784
:param seed: seed node id
88-
:param adj: adjacency list
85+
:param g: graph storage
8986
:param max_units: maximum number of units (nodes + edges) in the community
9087
:param used_n: set of used node ids
9188
:param used_e: set of used edge keys
@@ -105,7 +102,7 @@ def _grow_community(
105102
used_n.add(it)
106103
comm_n.append(it)
107104
cnt += 1
108-
for nei in adj[it]:
105+
for nei in g.get_neighbors(it):
109106
e_key = frozenset((it, nei))
110107
if e_key not in used_e:
111108
queue.append((EDGE_UNIT, e_key))

0 commit comments

Comments
 (0)