diff --git a/graphgen/engine.py b/graphgen/engine.py
index 2f1abf61..d09eb106 100644
--- a/graphgen/engine.py
+++ b/graphgen/engine.py
@@ -8,12 +8,38 @@
 import ray
 import ray.data
 from ray.data import DataContext
+from ray.data.block import Block
+from ray.data.datasource.filename_provider import FilenameProvider
 
 from graphgen.bases import Config, Node
 from graphgen.common import init_llm, init_storage
 from graphgen.utils import logger
 
 
+class NodeFilenameProvider(FilenameProvider):
+    def __init__(self, node_id: str):
+        self.node_id = node_id
+
+    def get_filename_for_block(
+        self, block: Block, write_uuid: str, task_index: int, block_index: int
+    ) -> str:
+        # format: {node_id}_{write_uuid}_{task_index:06}_{block_index:06}.jsonl
+        return f"{self.node_id}_{write_uuid}_{task_index:06d}_{block_index:06d}.jsonl"
+
+    def get_filename_for_row(
+        self,
+        row: Dict[str, Any],
+        write_uuid: str,
+        task_index: int,
+        block_index: int,
+        row_index: int,
+    ) -> str:
+        raise NotImplementedError(
+            f"Row-based filenames are not supported by write_json. "
+            f"Node: {self.node_id}, write_uuid: {write_uuid}"
+        )
+
+
 class Engine:
     def __init__(
         self, config: Dict[str, Any], functions: Dict[str, Callable], **ray_init_kwargs
@@ -263,13 +289,32 @@ def func_wrapper(row_or_batch: Dict[str, Any]) -> Dict[str, Any]:
                     f"Unsupported node type {node.type} for node {node.id}"
                 )
 
-    def execute(self, initial_ds: ray.data.Dataset) -> Dict[str, ray.data.Dataset]:
+    def execute(
+        self, initial_ds: ray.data.Dataset, output_dir: str
+    ) -> Dict[str, ray.data.Dataset]:
         sorted_nodes = self._topo_sort(self.config.nodes)
 
         for node in sorted_nodes:
+            logger.info("Executing node %s of type %s", node.id, node.type)
             self._execute_node(node, initial_ds)
             if getattr(node, "save_output", False):
-                self.datasets[node.id] = self.datasets[node.id].materialize()
+                node_output_path = os.path.join(output_dir, f"{node.id}")
+                os.makedirs(node_output_path, exist_ok=True)
+                logger.info("Saving output of node %s to %s", node.id, node_output_path)
+
+                ds = self.datasets[node.id]
+                ds.write_json(
+                    node_output_path,
+                    filename_provider=NodeFilenameProvider(node.id),
+                    pandas_json_args_fn=lambda: {
+                        "orient": "records",
+                        "lines": True,
+                        "force_ascii": False,
+                    },
+                )
+                logger.info("Node %s output saved to %s", node.id, node_output_path)
+
+                # Ray will lazily re-read the dataset from disk
+                self.datasets[node.id] = ray.data.read_json(node_output_path)
 
-        output_nodes = [n for n in sorted_nodes if getattr(n, "save_output", False)]
-        return {node.id: self.datasets[node.id] for node in output_nodes}
+        return self.datasets
diff --git a/graphgen/operators/generate/generate_service.py b/graphgen/operators/generate/generate_service.py
index 104ab88f..2107876b 100644
--- a/graphgen/operators/generate/generate_service.py
+++ b/graphgen/operators/generate/generate_service.py
@@ -1,3 +1,5 @@
+import json
+
 import pandas as pd
 
 from graphgen.bases import BaseLLMWrapper, BaseOperator
@@ -85,7 +87,9 @@ def generate(self, items: list[dict]) -> list[dict]:
         :return: QA pairs
         """
         logger.info("[Generation] mode: %s, batches: %d", self.method, len(items))
-        items = [(item["nodes"], item["edges"]) for item in items]
+        items = [
+            (json.loads(item["nodes"]), json.loads(item["edges"])) for item in items
+        ]
         results = run_concurrent(
             self.generator.generate,
             items,
diff --git a/graphgen/operators/partition/partition_service.py b/graphgen/operators/partition/partition_service.py
index ff215fce..6622e411 100644
--- a/graphgen/operators/partition/partition_service.py
+++ b/graphgen/operators/partition/partition_service.py
@@ -89,9 +89,10 @@ def partition(self) -> Iterable[pd.DataFrame]:
 
             yield pd.DataFrame(
                 {
-                    "nodes": [batch[0]],
-                    "edges": [batch[1]],
-                }
+                    "nodes": json.dumps(batch[0]),
+                    "edges": json.dumps(batch[1]),
+                },
+                index=[0],
             )
 
         logger.info("Total communities partitioned: %d", count)
diff --git a/graphgen/run.py b/graphgen/run.py
index 6b303ee1..26e752ae 100644
--- a/graphgen/run.py
+++ b/graphgen/run.py
@@ -2,13 +2,10 @@
 import os
 import time
 from importlib import resources
-from typing import Any, Dict
 
 import ray
 import yaml
 from dotenv import load_dotenv
-from ray.data.block import Block
-from ray.data.datasource.filename_provider import FilenameProvider
 
 from graphgen.engine import Engine
 from graphgen.operators import operators
@@ -32,30 +29,6 @@ def save_config(config_path, global_config):
     )
 
 
-class NodeFilenameProvider(FilenameProvider):
-    def __init__(self, node_id: str):
-        self.node_id = node_id
-
-    def get_filename_for_block(
-        self, block: Block, write_uuid: str, task_index: int, block_index: int
-    ) -> str:
-        # format: {node_id}_{write_uuid}_{task_index:06}_{block_index:06}.json
-        return f"{self.node_id}_{write_uuid}_{task_index:06d}_{block_index:06d}.jsonl"
-
-    def get_filename_for_row(
-        self,
-        row: Dict[str, Any],
-        write_uuid: str,
-        task_index: int,
-        block_index: int,
-        row_index: int,
-    ) -> str:
-        raise NotImplementedError(
-            f"Row-based filenames are not supported by write_json. "
-            f"Node: {self.node_id}, write_uuid: {write_uuid}"
-        )
-
-
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument(
@@ -91,22 +64,7 @@ def main():
     engine = Engine(config, operators)
 
     ds = ray.data.from_items([])
-    results = engine.execute(ds)
-
-    for node_id, dataset in results.items():
-        logger.info("Saving results for node %s", node_id)
-        node_output_path = os.path.join(output_path, f"{node_id}")
-        os.makedirs(node_output_path, exist_ok=True)
-        dataset.write_json(
-            node_output_path,
-            filename_provider=NodeFilenameProvider(node_id),
-            pandas_json_args_fn=lambda: {
-                "force_ascii": False,
-                "orient": "records",
-                "lines": True,
-            },
-        )
-        logger.info("Node %s results saved to %s", node_id, node_output_path)
+    engine.execute(ds, output_dir=output_path)
 
     save_config(os.path.join(output_path, "config.yaml"), config)
     logger.info("GraphGen completed successfully. Data saved to %s", output_path)
diff --git a/webui/app.py b/webui/app.py
index 08f1907a..140122e4 100644
--- a/webui/app.py
+++ b/webui/app.py
@@ -1,7 +1,7 @@
+import gc
 import json
 import os
 import sys
-import gc
 import tempfile
 from importlib.resources import files
 
@@ -188,7 +188,7 @@ def run_graphgen(params: WebuiParams, progress=gr.Progress()):
         ds = ray.data.from_items([])
 
         # Execute pipeline
-        results = engine.execute(ds)
+        results = engine.execute(ds, output_dir=working_dir)
 
         # 5. Process Output
         # Extract the result from the 'generate' node
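For reviewers who want to exercise the new persistence path in isolation, below is a minimal, standalone sketch of the write-then-lazy-reread pattern that `Engine.execute` now uses. It assumes a Ray Data version with `FilenameProvider` support; the node id `"generate"`, the sample rows, and the `/tmp/generate` output directory are illustrative, not taken from the repository.

```python
import ray
from ray.data.block import Block
from ray.data.datasource.filename_provider import FilenameProvider


class NodeFilenameProvider(FilenameProvider):
    """Names output files {node_id}_{write_uuid}_{task:06}_{block:06}.jsonl,
    mirroring the class added to graphgen/engine.py in this patch."""

    def __init__(self, node_id: str):
        self.node_id = node_id

    def get_filename_for_block(
        self, block: Block, write_uuid: str, task_index: int, block_index: int
    ) -> str:
        return f"{self.node_id}_{write_uuid}_{task_index:06d}_{block_index:06d}.jsonl"


# Hypothetical sample rows standing in for a node's output dataset.
ds = ray.data.from_items([{"question": "q1", "answer": "a1"}])

# Write JSON Lines with deterministic, node-scoped filenames.
ds.write_json(
    "/tmp/generate",
    filename_provider=NodeFilenameProvider("generate"),
    pandas_json_args_fn=lambda: {
        "orient": "records",
        "lines": True,
        "force_ascii": False,
    },
)

# Re-reading keeps the downstream dataset lazy instead of materialized in memory.
reread = ray.data.read_json("/tmp/generate")
print(reread.take_all())
```

Persisting inside `Engine.execute` and re-reading via `ray.data.read_json` trades a round-trip through disk for bounded memory: downstream nodes stream blocks back from the JSONL files rather than holding each saved node's materialized output in the object store, which is what the removed `materialize()` call did.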