
Commit 9bbe7ee
1 parent d0df8f0

Feat: add trace info & task storage (#168)

* refactor: use kv_storage for cache of ParallelFileScanner
* fix: limit before adding new data
* wip: trace generated data
* feat: add read_storage
* fix: delete &quot in kg
* wip: add checkpoint
* wip: move storage logic to baseOperator
* refactor: refactor evaluators
* test: add e2e_test for triple_evaluation
* fix: fix lint problem
* fix: fix lint errors
* fix: fix lint problems
* fix: fix partition service

File tree: 77 files changed, +1266 −2167 lines changed

(Large commits hide some content by default; a few file paths below are therefore missing.)

baselines/BDS/bds.py
Lines changed: 1 addition & 1 deletion

@@ -9,7 +9,7 @@

 from graphgen.bases import BaseLLMWrapper
 from graphgen.common import init_llm
-from graphgen.models import NetworkXStorage
+from graphgen.storage import NetworkXStorage
 from graphgen.utils import create_event_loop

 QA_GENERATION_PROMPT = """

examples/evaluate/evaluate_kg/kg_evaluation_config.yaml
Lines changed: 3 additions & 4 deletions

@@ -10,7 +10,7 @@ nodes:
     dependencies: []
     params:
       input_path:
-        - examples/input_examples/extract_demo.txt
+        - examples/input_examples/jsonl_demo.jsonl

   - id: chunk
     op_name: chunk
@@ -39,7 +39,6 @@ nodes:
     dependencies:
       - build_kg
     params:
+      target: kg
       metrics:
-        - kg_structure
-        - kg_accuracy
-        - kg_consistency
+        - structure
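With target: kg, the evaluate node scores the graph as a whole, and metric names drop their kg_ prefix. For intuition, a structure-style metric usually summarizes graph topology; the sketch below is hypothetical (plain networkx, not graphgen's own KG evaluator, which this excerpt does not show):

# Hypothetical "structure"-style metric over a built graph.
# Plain networkx; graphgen's KG evaluator is not part of this excerpt.
import networkx as nx


def structure_stats(graph: nx.Graph) -> dict[str, float]:
    n = graph.number_of_nodes()
    m = graph.number_of_edges()
    return {
        "nodes": float(n),
        "edges": float(m),
        "avg_degree": (2 * m / n) if n else 0.0,
        "components": float(nx.number_connected_components(graph)),
    }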

examples/evaluate/evaluate_qa/qa_evaluation_config.yaml
Lines changed: 7 additions & 6 deletions

@@ -1,7 +1,7 @@
 global_params:
   working_dir: cache
-  graph_backend: kuzu # graph database backend, support: kuzu, networkx
-  kv_backend: rocksdb # key-value store backend, support: rocksdb, json_kv
+  graph_backend: networkx # graph database backend, support: kuzu, networkx
+  kv_backend: json_kv # key-value store backend, support: rocksdb, json_kv

 nodes:
   - id: read_files # id is unique in the pipeline, and can be referenced by other steps
@@ -89,10 +89,11 @@ nodes:
       batch_size: 128
     save_output: true
     params:
+      target: qa
       metrics:
-        - qa_length
-        - qa_mtld
-        # - qa_reward_score
-        # - qa_uni_score
+        - length
+        - mtld
+        # - reward_score
+        # - uni_score
       mtld_params:
         threshold: 0.7
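The mtld metric here is the Measure of Textual Lexical Diversity, with its type-token-ratio cutoff exposed as mtld_params.threshold. A simplified single-pass sketch of the idea (illustrative only; graphgen's actual implementation is not shown in this diff):

# Simplified forward-pass MTLD (after McCarthy & Jarvis). Illustrative only,
# not graphgen's implementation; threshold mirrors mtld_params.threshold above.
def mtld_forward(tokens: list[str], threshold: float = 0.7) -> float:
    factors = 0.0
    seen: set[str] = set()
    count = 0
    for token in tokens:
        count += 1
        seen.add(token.lower())
        ttr = len(seen) / count      # type-token ratio of the current segment
        if ttr <= threshold:         # segment exhausted: count one full factor
            factors += 1.0
            seen, count = set(), 0
    if count:                        # partial segment contributes fractionally
        ttr = len(seen) / count
        factors += (1.0 - ttr) / (1.0 - threshold)
    return len(tokens) / factors if factors else float(len(tokens))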
examples/evaluate/evaluate_triple/ — new run script (exact file name hidden in the large-commit view)
Lines changed: 2 additions & 0 deletions

@@ -0,0 +1,2 @@
+python3 -m graphgen.run \
+    --config_file examples/evaluate/evaluate_triple/triple_evaluation_config.yaml
examples/evaluate/evaluate_triple/triple_evaluation_config.yaml (new file; path inferred from the run script above)
Lines changed: 46 additions & 0 deletions

@@ -0,0 +1,46 @@
+global_params:
+  working_dir: cache
+  graph_backend: networkx # graph database backend, support: kuzu, networkx
+  kv_backend: json_kv # key-value store backend, support: rocksdb, json_kv
+
+nodes:
+  - id: read
+    op_name: read
+    type: source
+    dependencies: []
+    params:
+      input_path:
+        - examples/input_examples/jsonl_demo.jsonl
+
+  - id: chunk
+    op_name: chunk
+    type: map_batch
+    dependencies:
+      - read
+    execution_params:
+      replicas: 4
+    params:
+      chunk_size: 20480 # larger chunk size for better context
+      chunk_overlap: 2000
+
+  - id: build_kg
+    op_name: build_kg
+    type: map_batch
+    dependencies:
+      - chunk
+    execution_params:
+      replicas: 1
+      batch_size: 128
+
+  - id: evaluate
+    op_name: evaluate
+    type: aggregate
+    save_output: true
+    dependencies:
+      - build_kg
+    params:
+      target: triple
+      src_namespace: chunk
+      tgt_namespace: build_kg
+      metrics:
+        - accuracy
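The declared pipeline is read -> chunk -> build_kg -> evaluate, with the evaluate node comparing source chunks (src_namespace: chunk) against the triples extracted from them (tgt_namespace: build_kg). A quick way to sanity-check that topology with plain PyYAML, independent of graphgen internals (the path assumes the run script above):

# Print each node and its declared dependencies from the pipeline config.
import yaml

with open(
    "examples/evaluate/evaluate_triple/triple_evaluation_config.yaml"
) as f:
    config = yaml.safe_load(f)

for node in config["nodes"]:
    deps = ", ".join(node["dependencies"]) or "(source)"
    print(f"{node['id']:<10} <- {deps}")
# Expected: read <- (source); chunk <- read; build_kg <- chunk; evaluate <- build_kg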

graphgen/bases/__init__.py
Lines changed: 1 addition & 1 deletion

@@ -1,3 +1,4 @@
+from .base_evaluator import BaseKGEvaluator, BaseQAEvaluator, BaseTripleEvaluator
 from .base_extractor import BaseExtractor
 from .base_generator import BaseGenerator
 from .base_kg_builder import BaseKGBuilder
@@ -9,5 +10,4 @@
 from .base_splitter import BaseSplitter
 from .base_storage import BaseGraphStorage, BaseKVStorage, StorageNameSpace
 from .base_tokenizer import BaseTokenizer
-from .base_evaluator import BaseEvaluator
 from .datatypes import Chunk, Config, Node, QAPair, Token

graphgen/bases/base_evaluator.py
Lines changed: 21 additions & 2 deletions

@@ -1,10 +1,29 @@
 from abc import ABC, abstractmethod
+from typing import Any
+
+from .base_storage import BaseGraphStorage
 from .datatypes import QAPair


-class BaseEvaluator(ABC):
+class BaseQAEvaluator(ABC):
     @abstractmethod
-    def evaluate(self, pair: QAPair) -> float:
+    async def evaluate(self, pair: QAPair) -> dict[str, float]:
         """
         Evaluate the text and return a score.
         """
+
+
+class BaseKGEvaluator(ABC):
+    @abstractmethod
+    def evaluate(self, kg: BaseGraphStorage) -> dict[str, Any]:
+        """
+        Evaluate the whole graph and return a dict of scores.
+        """
+
+
+class BaseTripleEvaluator(ABC):
+    @abstractmethod
+    async def evaluate(self, unit: dict) -> dict[str, float]:
+        """
+        Evaluate a node/edge and return a score.
+        """
graphgen/bases/base_generator.py
Lines changed: 34 additions & 49 deletions

@@ -21,72 +21,57 @@ def build_prompt(

     @staticmethod
     @abstractmethod
-    def parse_response(response: str) -> Any:
+    def parse_response(response: str) -> list[dict]:
         """Parse the LLM response and return the generated QAs"""

     async def generate(
         self,
         batch: tuple[
             list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
         ],
-    ) -> dict[str, Any]:
+    ) -> list[dict]:
         """
         Generate QAs based on a given batch.
         :param batch
         :return: QA pairs
         """
-        result = {}
         prompt = self.build_prompt(batch)
         response = await self.llm_client.generate_answer(prompt)
         qa_pairs = self.parse_response(response)  # generate one or more QA pairs
-        result.update(qa_pairs)
-        return result
+        return qa_pairs

     @staticmethod
     def format_generation_results(
-        results: list[dict], output_data_format: str
-    ) -> list[dict[str, Any]]:
+        result: dict, output_data_format: str
+    ) -> dict[str, Any]:
+        question = result.get("question", "")
+        answer = result.get("answer", "")
+        if "options" in result and result["options"]:
+            options = result["options"]
+            options_str = "\n".join(
+                [f"{key}. {options[key]}" for key in sorted(options.keys())]
+            )
+            question += f"\nOptions:\n{options_str}"

-        flat_results = []
-        for item in results:
-            for _, qa_data in item.items():
-                question = qa_data.get("question", "")
-                answer = qa_data.get("answer", "")
-                if "options" in qa_data and qa_data["options"]:
-                    options = qa_data["options"]
-                    options_str = "\n".join(
-                        [f"{key}. {options[key]}" for key in sorted(options.keys())]
-                    )
-                    question += f"\nOptions:\n{options_str}"
+        if output_data_format == "Alpaca":
+            return {
+                "instruction": question,
+                "input": "",
+                "output": answer,
+            }

-                if output_data_format == "Alpaca":
-                    flat_results.append(
-                        {
-                            "instruction": question,
-                            "input": "",
-                            "output": answer,
-                        }
-                    )
-                elif output_data_format == "Sharegpt":
-                    flat_results.append(
-                        {
-                            "conversations": [
-                                {"from": "human", "value": question},
-                                {"from": "gpt", "value": answer},
-                            ]
-                        }
-                    )
-                elif output_data_format == "ChatML":
-                    flat_results.append(
-                        {
-                            "messages": [
-                                {"role": "user", "content": question},
-                                {"role": "assistant", "content": answer},
-                            ]
-                        }
-                    )
-                else:
-                    raise ValueError(
-                        f"Unknown output data format: {output_data_format}"
-                    )
-        return flat_results
+        if output_data_format == "Sharegpt":
+            return {
+                "conversations": [
+                    {"from": "human", "value": question},
+                    {"from": "gpt", "value": answer},
+                ]
+            }
+        if output_data_format == "ChatML":
+            return {
+                "messages": [
+                    {"role": "user", "content": question},
+                    {"role": "assistant", "content": answer},
+                ]
+            }
+        raise ValueError(f"Unknown output data format: {output_data_format}")
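format_generation_results now maps one parsed QA dict to one formatted record (Alpaca, Sharegpt, or ChatML) instead of flattening a list of keyed results; the per-item loop moves to the caller. Since it is a @staticmethod, it can be exercised directly, assuming it lives on BaseGenerator as the file suggests (sample data made up for illustration):

# One QA dict in, one ChatML record out; options, if present, are folded
# into the question text.
from graphgen.bases import BaseGenerator

qa = {
    "question": "What does the evaluate node compare?",
    "answer": "Source chunks against the triples extracted from them.",
}
record = BaseGenerator.format_generation_results(qa, "ChatML")
# {"messages": [{"role": "user", "content": "..."},
#               {"role": "assistant", "content": "..."}]}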
