53 changes: 49 additions & 4 deletions graphgen/engine.py
@@ -8,12 +8,38 @@
import ray
import ray.data
from ray.data import DataContext
from ray.data.block import Block
from ray.data.datasource.filename_provider import FilenameProvider

from graphgen.bases import Config, Node
from graphgen.common import init_llm, init_storage
from graphgen.utils import logger


class NodeFilenameProvider(FilenameProvider):
def __init__(self, node_id: str):
self.node_id = node_id

def get_filename_for_block(
self, block: Block, write_uuid: str, task_index: int, block_index: int
) -> str:
# format: {node_id}_{write_uuid}_{task_index:06}_{block_index:06}.jsonl
return f"{self.node_id}_{write_uuid}_{task_index:06d}_{block_index:06d}.jsonl"

def get_filename_for_row(
self,
row: Dict[str, Any],
write_uuid: str,
task_index: int,
block_index: int,
row_index: int,
) -> str:
raise NotImplementedError(
f"Row-based filenames are not supported by write_json. "
f"Node: {self.node_id}, write_uuid: {write_uuid}"
)


class Engine:
def __init__(
self, config: Dict[str, Any], functions: Dict[str, Callable], **ray_init_kwargs
@@ -263,13 +289,32 @@ def func_wrapper(row_or_batch: Dict[str, Any]) -> Dict[str, Any]:
f"Unsupported node type {node.type} for node {node.id}"
)

def execute(self, initial_ds: ray.data.Dataset) -> Dict[str, ray.data.Dataset]:
def execute(
self, initial_ds: ray.data.Dataset, output_dir: str
) -> Dict[str, ray.data.Dataset]:
sorted_nodes = self._topo_sort(self.config.nodes)

for node in sorted_nodes:
logger.info("Executing node %s of type %s", node.id, node.type)
self._execute_node(node, initial_ds)
if getattr(node, "save_output", False):
self.datasets[node.id] = self.datasets[node.id].materialize()
node_output_path = os.path.join(output_dir, f"{node.id}")
os.makedirs(node_output_path, exist_ok=True)
logger.info("Saving output of node %s to %s", node.id, node_output_path)

ds = self.datasets[node.id]
ds.write_json(
node_output_path,
filename_provider=NodeFilenameProvider(node.id),
pandas_json_args_fn=lambda: {
"orient": "records",
"lines": True,
"force_ascii": False,
},
)
logger.info("Node %s output saved to %s", node.id, node_output_path)

# Ray will lazily re-read the dataset from the written files
self.datasets[node.id] = ray.data.read_json(node_output_path)

output_nodes = [n for n in sorted_nodes if getattr(n, "save_output", False)]
return {node.id: self.datasets[node.id] for node in output_nodes}
return self.datasets
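
Taken together, `execute` now materializes each `save_output` node, writes it out as JSON Lines via the `NodeFilenameProvider` added above, and then swaps the in-memory dataset for a lazy on-disk read so downstream nodes consume the persisted files. Below is a minimal sketch of that round-trip, not part of the PR; it assumes the same `write_uuid`-style `FilenameProvider` signature the diff uses, and the paths and node id are illustrative.

```python
import os

import ray

from graphgen.engine import NodeFilenameProvider  # the class added above

ray.init(ignore_reinit_error=True)

ds = ray.data.from_items([{"question": "Q1", "answer": "A1"}])
out_dir = "/tmp/graphgen_demo/generate"  # stands in for output_dir/node_id
os.makedirs(out_dir, exist_ok=True)

# Files land as {node_id}_{write_uuid}_{task_index:06d}_{block_index:06d}.jsonl
ds.write_json(
    out_dir,
    filename_provider=NodeFilenameProvider("generate"),
    pandas_json_args_fn=lambda: {
        "orient": "records",
        "lines": True,
        "force_ascii": False,
    },
)

# Downstream consumers get a lazy re-read of the files, not the in-memory data.
reloaded = ray.data.read_json(out_dir)
print(reloaded.take(1))
```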
6 changes: 5 additions & 1 deletion graphgen/operators/generate/generate_service.py
@@ -1,3 +1,5 @@
import json

import pandas as pd

from graphgen.bases import BaseLLMWrapper, BaseOperator
@@ -85,7 +87,9 @@ def generate(self, items: list[dict]) -> list[dict]:
:return: QA pairs
"""
logger.info("[Generation] mode: %s, batches: %d", self.method, len(items))
items = [(item["nodes"], item["edges"]) for item in items]
items = [
(json.loads(item["nodes"]), json.loads(item["edges"])) for item in items
]
results = run_concurrent(
self.generator.generate,
items,
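
The `json.loads` here is the read side of a serialization contract introduced by this PR: the partition step (next file) now stores each community's nodes and edges as JSON strings rather than nested Python objects. A small round-trip sketch with made-up values:

```python
import json

import pandas as pd

# One community batch, shaped like partition_service's output (illustrative).
nodes = [{"id": "n1"}, {"id": "n2"}]
edges = [["n1", "n2"]]

row = pd.DataFrame(
    {"nodes": json.dumps(nodes), "edges": json.dumps(edges)}, index=[0]
)

# generate_service decodes the strings back into Python structures.
decoded = (json.loads(row.loc[0, "nodes"]), json.loads(row.loc[0, "edges"]))
assert decoded == (nodes, edges)
```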
7 changes: 4 additions & 3 deletions graphgen/operators/partition/partition_service.py
@@ -89,9 +89,10 @@ def partition(self) -> Iterable[pd.DataFrame]:

yield pd.DataFrame(
{
"nodes": [batch[0]],
"edges": [batch[1]],
}
"nodes": json.dumps(batch[0]),
"edges": json.dumps(batch[1]),
},
index=[0],
)
logger.info("Total communities partitioned: %d", count)

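
Two details here are easy to miss: encoding the nested node and edge structures as JSON strings keeps them out of Ray Data's Arrow schema, and `index=[0]` is required because pandas rejects a DataFrame built entirely from scalar values without an explicit index. A quick demonstration (the exact error message may vary across pandas versions):

```python
import json

import pandas as pd

nodes, edges = [{"id": "n1"}], [["n1", "n2"]]  # illustrative community batch

try:
    pd.DataFrame({"nodes": json.dumps(nodes), "edges": json.dumps(edges)})
except ValueError as exc:
    print(exc)  # e.g. "If using all scalar values, you must pass an index"

df = pd.DataFrame(
    {"nodes": json.dumps(nodes), "edges": json.dumps(edges)}, index=[0]
)
print(df.shape)  # (1, 2): one single-row frame per community
```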
44 changes: 1 addition & 43 deletions graphgen/run.py
@@ -2,13 +2,10 @@
import os
import time
from importlib import resources
from typing import Any, Dict

import ray
import yaml
from dotenv import load_dotenv
from ray.data.block import Block
from ray.data.datasource.filename_provider import FilenameProvider

from graphgen.engine import Engine
from graphgen.operators import operators
@@ -32,30 +29,6 @@ def save_config(config_path, global_config):
)


class NodeFilenameProvider(FilenameProvider):
def __init__(self, node_id: str):
self.node_id = node_id

def get_filename_for_block(
self, block: Block, write_uuid: str, task_index: int, block_index: int
) -> str:
# format: {node_id}_{write_uuid}_{task_index:06}_{block_index:06}.json
return f"{self.node_id}_{write_uuid}_{task_index:06d}_{block_index:06d}.jsonl"

def get_filename_for_row(
self,
row: Dict[str, Any],
write_uuid: str,
task_index: int,
block_index: int,
row_index: int,
) -> str:
raise NotImplementedError(
f"Row-based filenames are not supported by write_json. "
f"Node: {self.node_id}, write_uuid: {write_uuid}"
)


def main():
parser = argparse.ArgumentParser()
parser.add_argument(
@@ -91,22 +64,7 @@ def main():

engine = Engine(config, operators)
ds = ray.data.from_items([])
results = engine.execute(ds)

for node_id, dataset in results.items():
logger.info("Saving results for node %s", node_id)
node_output_path = os.path.join(output_path, f"{node_id}")
os.makedirs(node_output_path, exist_ok=True)
dataset.write_json(
node_output_path,
filename_provider=NodeFilenameProvider(node_id),
pandas_json_args_fn=lambda: {
"force_ascii": False,
"orient": "records",
"lines": True,
},
)
logger.info("Node %s results saved to %s", node_id, node_output_path)
engine.execute(ds, output_dir=output_path)

save_config(os.path.join(output_path, "config.yaml"), config)
logger.info("GraphGen completed successfully. Data saved to %s", output_path)
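
With the filename provider and the save loop moved into the engine, `run.py` reduces to a single call. A sketch of the new call contract, with a hypothetical config path:

```python
import ray
import yaml

from graphgen.engine import Engine
from graphgen.operators import operators

with open("configs/graphgen_config.yaml") as f:  # hypothetical path
    config = yaml.safe_load(f)

engine = Engine(config, operators)

# execute() now persists every save_output node under output_dir itself and
# returns all node datasets, not only the save_output ones.
datasets = engine.execute(ray.data.from_items([]), output_dir="cache/output")
```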
4 changes: 2 additions & 2 deletions webui/app.py
@@ -1,7 +1,7 @@
import gc
import json
import os
import sys
import gc
import tempfile
from importlib.resources import files

@@ -188,7 +188,7 @@ def run_graphgen(params: WebuiParams, progress=gr.Progress()):
ds = ray.data.from_items([])

# Execute pipeline
results = engine.execute(ds)
results = engine.execute(ds, output_dir=working_dir)

# 5. Process Output
# Extract the result from the 'generate' node
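
Because `execute` now returns every node's dataset rather than only the `save_output` nodes, the webui can still pull the generation output by node id. A sketch, with the `generate` key assumed from the comment above:

```python
qa_ds = results["generate"]  # node id assumed from the surrounding comment
qa_pairs = qa_ds.take_all()  # list of QA-pair rows for the UI to display
```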