Skip to content

Commit 97a03f2

Browse files
fix: fix ray redundant execution (#165)
* fix: fix ray redundant execution * Update graphgen/engine.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * fix: delete useless code * fix: update webui --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 1f13f44 commit 97a03f2

File tree

5 files changed

+61
-53
lines changed

5 files changed

+61
-53
lines changed

graphgen/engine.py

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,38 @@
88
import ray
99
import ray.data
1010
from ray.data import DataContext
11+
from ray.data.block import Block
12+
from ray.data.datasource.filename_provider import FilenameProvider
1113

1214
from graphgen.bases import Config, Node
1315
from graphgen.common import init_llm, init_storage
1416
from graphgen.utils import logger
1517

1618

19+
class NodeFilenameProvider(FilenameProvider):
    """Produce deterministic, node-scoped filenames for Ray Data JSON writes.

    Each output block is written as
    ``{node_id}_{write_uuid}_{task_index:06}_{block_index:06}.jsonl`` so that
    files from different pipeline nodes sharing a directory never collide.
    """

    def __init__(self, node_id: str):
        # Id of the pipeline node whose output is being written; used as the
        # filename prefix.
        self.node_id = node_id

    def get_filename_for_block(
        self, block: Block, write_uuid: str, task_index: int, block_index: int
    ) -> str:
        # Assemble: {node_id}_{write_uuid}_{task_index:06}_{block_index:06}.jsonl
        stem = "_".join(
            [self.node_id, write_uuid, f"{task_index:06d}", f"{block_index:06d}"]
        )
        return stem + ".jsonl"

    def get_filename_for_row(
        self,
        row: Dict[str, Any],
        write_uuid: str,
        task_index: int,
        block_index: int,
        row_index: int,
    ) -> str:
        # write_json only emits per-block files, so per-row naming is
        # intentionally unsupported.
        raise NotImplementedError(
            f"Row-based filenames are not supported by write_json. "
            f"Node: {self.node_id}, write_uuid: {write_uuid}"
        )
42+
1743
class Engine:
1844
def __init__(
1945
self, config: Dict[str, Any], functions: Dict[str, Callable], **ray_init_kwargs
@@ -263,13 +289,32 @@ def func_wrapper(row_or_batch: Dict[str, Any]) -> Dict[str, Any]:
263289
f"Unsupported node type {node.type} for node {node.id}"
264290
)
265291

266-
def execute(
    self, initial_ds: ray.data.Dataset, output_dir: str
) -> Dict[str, ray.data.Dataset]:
    """Run every pipeline node in topological order.

    Nodes flagged with ``save_output`` have their dataset persisted to
    ``{output_dir}/{node_id}`` as JSONL and then re-read lazily, which
    prevents Ray from re-executing the upstream lineage when the dataset
    is consumed again.

    :param initial_ds: seed dataset fed to the first node(s)
    :param output_dir: root directory for per-node JSONL output
    :return: mapping of node id to its (possibly re-read) dataset
    """
    for node in self._topo_sort(self.config.nodes):
        logger.info("Executing node %s of type %s", node.id, node.type)
        self._execute_node(node, initial_ds)

        if not getattr(node, "save_output", False):
            continue

        node_output_path = os.path.join(output_dir, f"{node.id}")
        os.makedirs(node_output_path, exist_ok=True)
        logger.info("Saving output of node %s to %s", node.id, node_output_path)

        self.datasets[node.id].write_json(
            node_output_path,
            filename_provider=NodeFilenameProvider(node.id),
            pandas_json_args_fn=lambda: {
                "orient": "records",
                "lines": True,
                "force_ascii": False,
            },
        )
        logger.info("Node %s output saved to %s", node.id, node_output_path)

        # ray will lazy read the dataset
        self.datasets[node.id] = ray.data.read_json(node_output_path)

    return self.datasets

graphgen/operators/generate/generate_service.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import json
2+
13
import pandas as pd
24

35
from graphgen.bases import BaseLLMWrapper, BaseOperator
@@ -85,7 +87,9 @@ def generate(self, items: list[dict]) -> list[dict]:
8587
:return: QA pairs
8688
"""
8789
logger.info("[Generation] mode: %s, batches: %d", self.method, len(items))
88-
items = [(item["nodes"], item["edges"]) for item in items]
90+
items = [
91+
(json.loads(item["nodes"]), json.loads(item["edges"])) for item in items
92+
]
8993
results = run_concurrent(
9094
self.generator.generate,
9195
items,

graphgen/operators/partition/partition_service.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,10 @@ def partition(self) -> Iterable[pd.DataFrame]:
8989

9090
yield pd.DataFrame(
9191
{
92-
"nodes": [batch[0]],
93-
"edges": [batch[1]],
94-
}
92+
"nodes": json.dumps(batch[0]),
93+
"edges": json.dumps(batch[1]),
94+
},
95+
index=[0],
9596
)
9697
logger.info("Total communities partitioned: %d", count)
9798

graphgen/run.py

Lines changed: 1 addition & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,10 @@
22
import os
33
import time
44
from importlib import resources
5-
from typing import Any, Dict
65

76
import ray
87
import yaml
98
from dotenv import load_dotenv
10-
from ray.data.block import Block
11-
from ray.data.datasource.filename_provider import FilenameProvider
129

1310
from graphgen.engine import Engine
1411
from graphgen.operators import operators
@@ -32,30 +29,6 @@ def save_config(config_path, global_config):
3229
)
3330

3431

35-
class NodeFilenameProvider(FilenameProvider):
36-
def __init__(self, node_id: str):
37-
self.node_id = node_id
38-
39-
def get_filename_for_block(
40-
self, block: Block, write_uuid: str, task_index: int, block_index: int
41-
) -> str:
42-
# format: {node_id}_{write_uuid}_{task_index:06}_{block_index:06}.json
43-
return f"{self.node_id}_{write_uuid}_{task_index:06d}_{block_index:06d}.jsonl"
44-
45-
def get_filename_for_row(
46-
self,
47-
row: Dict[str, Any],
48-
write_uuid: str,
49-
task_index: int,
50-
block_index: int,
51-
row_index: int,
52-
) -> str:
53-
raise NotImplementedError(
54-
f"Row-based filenames are not supported by write_json. "
55-
f"Node: {self.node_id}, write_uuid: {write_uuid}"
56-
)
57-
58-
5932
def main():
6033
parser = argparse.ArgumentParser()
6134
parser.add_argument(
@@ -91,22 +64,7 @@ def main():
9164

9265
engine = Engine(config, operators)
9366
ds = ray.data.from_items([])
94-
results = engine.execute(ds)
95-
96-
for node_id, dataset in results.items():
97-
logger.info("Saving results for node %s", node_id)
98-
node_output_path = os.path.join(output_path, f"{node_id}")
99-
os.makedirs(node_output_path, exist_ok=True)
100-
dataset.write_json(
101-
node_output_path,
102-
filename_provider=NodeFilenameProvider(node_id),
103-
pandas_json_args_fn=lambda: {
104-
"force_ascii": False,
105-
"orient": "records",
106-
"lines": True,
107-
},
108-
)
109-
logger.info("Node %s results saved to %s", node_id, node_output_path)
67+
engine.execute(ds, output_dir=output_path)
11068

11169
save_config(os.path.join(output_path, "config.yaml"), config)
11270
logger.info("GraphGen completed successfully. Data saved to %s", output_path)

webui/app.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1+
import gc
12
import json
23
import os
34
import sys
4-
import gc
55
import tempfile
66
from importlib.resources import files
77

@@ -188,7 +188,7 @@ def run_graphgen(params: WebuiParams, progress=gr.Progress()):
188188
ds = ray.data.from_items([])
189189

190190
# Execute pipeline
191-
results = engine.execute(ds)
191+
results = engine.execute(ds, output_dir=working_dir)
192192

193193
# 5. Process Output
194194
# Extract the result from the 'generate' node

0 commit comments

Comments
 (0)