Skip to content

Commit 6a0f682

Browse files
feat: add rephrasing prompts (#163)
* feat: add rephrasing prompts * Update graphgen/templates/rephrasing/qa_dialogue_format_rephrasing.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * feat: add rephrasing pipeline * fix: change chunk params * fix: fix rephrasers --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 661211e commit 6a0f682

37 files changed

+711
-20
lines changed

baselines/BDS/bds.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from tqdm.asyncio import tqdm as tqdm_async
99

1010
from graphgen.bases import BaseLLMWrapper
11-
from graphgen.common import init_llm
11+
from graphgen.common.init_llm import init_llm
1212
from graphgen.storage import NetworkXStorage
1313
from graphgen.utils import create_event_loop
1414

examples/generate/generate_aggregated_qa/aggregated_config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
global_params:
22
working_dir: cache
3-
graph_backend: kuzu # graph database backend, support: kuzu, networkx
4-
kv_backend: rocksdb # key-value store backend, support: rocksdb, json_kv
3+
graph_backend: networkx # graph database backend, support: kuzu, networkx
4+
kv_backend: json_kv # key-value store backend, support: rocksdb, json_kv
55

66
nodes:
77
- id: read_files # id is unique in the pipeline, and can be referenced by other steps
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# Rephrase with Style Control
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
python3 -m graphgen.run \
2+
--config_file examples/rephrase/rephrase_style_controlled/style_controlled_rephrasing_config.yaml
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
global_params:
2+
working_dir: cache
3+
kv_backend: rocksdb # key-value store backend, support: rocksdb, json_kv
4+
5+
nodes:
6+
- id: read
7+
op_name: read
8+
type: source
9+
dependencies: []
10+
params:
11+
input_path:
12+
- examples/input_examples/json_demo.json
13+
14+
- id: chunk
15+
op_name: chunk
16+
type: map_batch
17+
dependencies:
18+
- read
19+
execution_params:
20+
replicas: 4
21+
params:
22+
chunk_size: 2048 # larger chunk size for better context
23+
chunk_overlap: 200
24+
25+
- id: rephrase
26+
op_name: rephrase
27+
type: map_batch
28+
dependencies:
29+
- chunk
30+
execution_params:
31+
replicas: 1
32+
batch_size: 128
33+
save_output: true
34+
params:
35+
method: style_controlled
36+
style: critical_analysis

graphgen/bases/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from .base_operator import BaseOperator
88
from .base_partitioner import BasePartitioner
99
from .base_reader import BaseReader
10+
from .base_rephraser import BaseRephraser
1011
from .base_searcher import BaseSearcher
1112
from .base_splitter import BaseSplitter
1213
from .base_storage import BaseGraphStorage, BaseKVStorage, StorageNameSpace

graphgen/bases/base_operator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def __init__(
2828
op_name: str = None,
2929
):
3030
# lazy import to avoid circular import
31-
from graphgen.common import init_storage
31+
from graphgen.common.init_storage import init_storage
3232
from graphgen.utils import set_logger
3333

3434
log_dir = os.path.join(working_dir, "logs")

graphgen/bases/base_rephraser.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from abc import ABC, abstractmethod
2+
from typing import Any
3+
4+
from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
5+
6+
7+
class BaseRephraser(ABC):
    """
    Abstract base for rephrasers that rewrite text with an LLM.

    Subclasses supply the prompt construction (``build_prompt``) and the
    response parsing (``parse_response``); ``rephrase`` wires the two together
    around a single call to the LLM client.
    """

    def __init__(self, llm_client: BaseLLMWrapper):
        """
        :param llm_client: client whose awaitable ``generate_answer(prompt)``
            is used to obtain the raw LLM response.
        """
        self.llm_client = llm_client

    @abstractmethod
    def build_prompt(self, text: str) -> str:
        """Build prompt for LLM based on the given text"""

    @staticmethod
    @abstractmethod
    def parse_response(response: str) -> Any:
        """Parse the LLM response and return the rephrased text"""

    async def rephrase(
        self,
        item: dict,
    ) -> Any:
        """
        Rephrase a single item.

        :param item: mapping that must contain the source text under the
            ``"content"`` key.
        :return: whatever ``parse_response`` produces for the LLM response;
            the value is passed through unchanged, so its type is defined by
            the concrete subclass.
        :raises KeyError: if ``item`` has no ``"content"`` key.
        """
        text = item["content"]
        prompt = self.build_prompt(text)
        response = await self.llm_client.generate_answer(prompt)
        return self.parse_response(response)

graphgen/common/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
from .init_llm import init_llm
2-
from .init_storage import init_storage
1+
# from .init_llm import init_llm
2+
# from .init_storage import init_storage

graphgen/engine.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
from ray.data.datasource.filename_provider import FilenameProvider
1212

1313
from graphgen.bases import Config, Node
14-
from graphgen.common import init_llm, init_storage
14+
from graphgen.common.init_llm import init_llm
15+
from graphgen.common.init_storage import init_storage
1516
from graphgen.utils import logger
1617

1718

@@ -70,6 +71,7 @@ def __init__(
7071

7172
if not ray.is_initialized():
7273
context = ray.init(
74+
include_dashboard=True,
7375
ignore_reinit_error=True,
7476
logging_level=logging.ERROR,
7577
log_to_driver=True,

0 commit comments

Comments
 (0)