diff --git a/demo/condmatgen_tiny/NatureLM_conditional_v2.json b/demo/condmatgen_tiny/NatureLM_conditional_v2.json new file mode 100644 index 00000000..767c3769 --- /dev/null +++ b/demo/condmatgen_tiny/NatureLM_conditional_v2.json @@ -0,0 +1,52 @@ +[ + {"elements": "O Te Tm"}, + {"elements": "Si Al O"}, + {"elements": "Fe Ni"}, + {"elements": "Li Co O"}, + {"elements": "Na Cl"}, + {"elements": "Ca Ti O"}, + {"elements": "Mg O"}, + {"elements": "Zn S"}, + {"elements": "Cu Fe S"}, + {"elements": "Ba Ti O"}, + {"elements": "Sr Ti O"}, + {"elements": "Ga As"}, + {"elements": "In P"}, + {"elements": "Al N"}, + {"elements": "Si C"}, + {"elements": "Fe O"}, + {"elements": "Mn O"}, + {"elements": "Cr O"}, + {"elements": "V O"}, + {"elements": "Ti O"}, + {"elements": "Li Mn O"}, + {"elements": "Li Fe P O"}, + {"elements": "Na Fe O"}, + {"elements": "K Al Si O"}, + {"elements": "Ca Mg Si O"}, + {"elements": "Zr O"}, + {"elements": "Hf O"}, + {"elements": "Nb O"}, + {"elements": "Mo S"}, + {"elements": "W Se"}, + {"elements": "Bi Te"}, + {"elements": "Sb Te"}, + {"elements": "Pb Te"}, + {"elements": "Sn Se"}, + {"elements": "Ge Te"}, + {"elements": "Cd Se"}, + {"elements": "Zn Te"}, + {"elements": "Cu Zn Sn S"}, + {"elements": "Ag Bi Se"}, + {"elements": "La Mn O"}, + {"elements": "Y Ba Cu O"}, + {"elements": "Nd Fe B"}, + {"elements": "Sm Co"}, + {"elements": "Ce O"}, + {"elements": "Pr Ni O"}, + {"elements": "Gd Fe O"}, + {"elements": "Eu Ti O"}, + {"elements": "Tb Mn O"}, + {"elements": "Dy Fe O"}, + {"elements": "Ho Co O"} +] diff --git a/demo/fixture_manifest.csv b/demo/fixture_manifest.csv index 1996d9a7..1c835e69 100644 --- a/demo/fixture_manifest.csv +++ b/demo/fixture_manifest.csv @@ -12,3 +12,4 @@ rxn_inversion,demo/datasets/rxn_inversion_sample.csv,50,"~45 train / ~5 test",re rxn_replacement,demo/datasets/rxn_inversion_sample.csv,50,"~45 train / ~5 test",repo-local fixture,ready,"Same schema as rxn_inversion (MCQ with 4 options)" 
rxn_naming,demo/datasets/rxn_naming_sample.csv,50,"~45 train / ~5 test",repo-local fixture,ready,"Reaction classification into 10 named categories" rxn_truefalse,demo/datasets/rxn_truefalse_sample.csv,50,"~45 train / ~5 test",repo-local fixture,ready,"Binary true/false reaction validity" +condmatgen,demo/condmatgen_tiny,50,"~45 train / ~5 test",repo-local fixture,ready,"JSON with element lists for conditional material generation" diff --git a/demo/run_fixture_smoke.py b/demo/run_fixture_smoke.py index 4316f664..937de084 100644 --- a/demo/run_fixture_smoke.py +++ b/demo/run_fixture_smoke.py @@ -43,6 +43,9 @@ def main(): "rxn_truefalse": datasets_dir / "rxn_truefalse_sample.csv", } + + if "condmatgen" in CHEMTASKS: + task_configs["condmatgen"] = demo_dir / "condmatgen_tiny" + summary = {} for task_name, dataset_path in task_configs.items(): task_class = CHEMTASKS[task_name] diff --git a/docs/source/modules.rst b/docs/source/modules.rst index e9a2c264..a930a00c 100644 --- a/docs/source/modules.rst +++ b/docs/source/modules.rst @@ -17,4 +17,5 @@ Modules Reference tasks/rxn_naming tasks/rxn_truefalse tasks/template - tasks/smi_permute \ No newline at end of file + tasks/smi_permute + tasks/condmatgen \ No newline at end of file diff --git a/docs/source/tasks/condmatgen.rst b/docs/source/tasks/condmatgen.rst new file mode 100644 index 00000000..997b5532 --- /dev/null +++ b/docs/source/tasks/condmatgen.rst @@ -0,0 +1,26 @@ +Conditional Material Generation (CMG) +====================================== + +.. currentmodule:: open_r1.tasks.condmatgen.condmatgen + +ConditionalMaterialGeneration +----------------------------- + +.. autoclass:: ConditionalMaterialGeneration + :members: + :show-inheritance: + +Task Description +---------------- + +Given a set of chemical elements, the model proposes a novel crystalline +compound (element list and space group number). The model wraps its reasoning +in ``<think>...</think>`` tags and its answer in ``<answer>...</answer>`` tags. 
+ +Reward Functions +---------------- + +- **accuracy**: multi-component scoring including SMACT validity, element + precision, space group validity, and novelty bonus. +- **format**: checks presence and ordering of think/answer tags, penalizes + short reasoning. diff --git a/recipes/condmatgen.yaml b/recipes/condmatgen.yaml new file mode 100644 index 00000000..028afba3 --- /dev/null +++ b/recipes/condmatgen.yaml @@ -0,0 +1,49 @@ +# Model arguments +model_revision: main +torch_dtype: bfloat16 +attn_implementation: flash_attention_2 +bf16: true +tf32: true + +# Chemical Task arguments +chem_task: condmatgen +dataset_id_or_path: ${MIST_DATA_DIR}/condmatgen +rewards: +- accuracy + +# Lora Arguments +# No LoRA is used here + +# Training arguments +max_steps: 1450 +per_device_train_batch_size: 1 +gradient_accumulation_steps: 8 +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false +learning_rate: 2.0e-6 # 1.0e-6 as in the deepseek math paper 5-e7 from https://hijkzzz.notion.site/unraveling-rlhf-and-its-variants-engineering-insights#147d9a33ecc9806090f3d5c749d31f05 +lr_scheduler_type: cosine +warmup_ratio: 0.03 +# GRPO specific parameters +beta: 0.04 # 0.04 as in the deepseek math paper 0.001 from https://hijkzzz.notion.site/unraveling-rlhf-and-its-variants-engineering-insights#147d9a33ecc9806090f3d5c749d31f05 +max_prompt_length: 600 +max_completion_length: 2048 +num_generations: 4 +use_vllm: true +vllm_device: "cuda:3" +vllm_gpu_memory_utilization: 0.8 +vllm_max_model_len: 2048 + +# Logging arguments +logging_strategy: steps +logging_steps: 1 +report_to: +- wandb + +save_strategy: "steps" +save_steps: 25 +seed: 42 + +# Hugging Face Hub +push_to_hub: false + # hub_model_id: llama-3-1-8b-math-orca-qlora-10k-ep1 # if not defined same as output_dir \ No newline at end of file diff --git a/src/open_r1/tasks/__init__.py b/src/open_r1/tasks/__init__.py index 8812ce63..b477254c 100644 --- a/src/open_r1/tasks/__init__.py +++ 
import json
import os
import re
from dataclasses import field
from typing import Any, Dict

from datasets import Dataset, DatasetDict

from open_r1.paths import expand_path

from ..base import RLTask


class ConditionalMaterialGeneration(RLTask):
    """RL task: given a set of chemical elements, propose a novel crystalline compound.

    The model reasons inside ``<think>...</think>`` and answers with a single
    ``<answer>...</answer>`` line containing space-separated element symbols
    followed by a ``<sg>NUMBER</sg>`` space-group tag.  The reward combines
    format compliance, element precision against the prompt, space-group
    validity (1..230), SMACT charge-neutrality screening, and a novelty bonus
    against compositions already used during SFT.
    """

    question_template: str = ""
    log_custom_metrics: bool = True
    # NOTE(review): dataclasses.field() only takes effect if RLTask turns this
    # class into a dataclass — TODO confirm; both attributes are re-assigned
    # per instance in __init__ regardless.
    custom_metrics: dict = field(default_factory=dict)
    seen_comps_set: set = field(default_factory=set)
    # NOTE(review): class-level mutable default, shared across instances until
    # first overwritten in accuracy_reward(); kept for interface compatibility.
    random_log: Dict[str, Any] = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Output contract shown to the model.  The tag literals (<think>,
        # <answer>, <sg>) must stay in sync with the regexes/find() calls in
        # accuracy_reward() below.
        self.system_prompt = (
            "You are a careful model that must follow the Output Contract exactly.\n\n"
            "OUTPUT CONTRACT\n"
            "1) You may think only inside <think>...</think>.\n"
            "2) Your final answer must be a single line wrapped in <answer>...</answer>.\n"
            "3) Inside <answer>, list only element symbols separated by single spaces, "
            "followed by a space-group tag <sg>NUMBER</sg>.\n"
            "4) If you generate a chemical formula with subscripts (e.g. Ni\u2082Fe\u2084LiO\u2081\u2080), "
            "you must expand them into repeated symbols "
            "(Ni Ni Fe Fe Fe Fe Li O O O O O O O O O O).\n"
            "5) Do not include any extra words, punctuation, examples, explanations, "
            "or text outside the tags.\n"
            "6) After you produce </answer>, you must stop. No tokens are allowed after </answer>.\n\n"
            "VALID EXAMPLE (format only):\n"
            "<think>reasoning</think>\n"
            # NOTE(review): the example space-group number was lost in
            # extraction; 227 is a reconstruction — any value in 1..230 works.
            "<answer> Ca O Sn Sn <sg>227</sg> </answer>\n\n"
            "INVALID EXAMPLES:\n"
            "- Missing tags\n"
            "- Commas, bullets, or explanations inside <answer>\n"
            "- Multiple <answer> blocks\n"
            "- Extra output after </answer>\n\n"
            "Allowed tokens inside <answer>: element symbols (H He Li ... Og), "
            "single spaces, and the literal pattern <sg>NUMBER</sg>."
        )

        self.question_template = (
            "You are a materials science expert.\n"
            "Given the following elements: {}, propose one chemically valid and novel crystalline compound."
        )

        self.log_custom_metrics = True
        self.custom_metrics = {
            "val/rewards": [],
        }

        # Compositions already seen during SFT: answers matching one of these
        # do not receive the novelty bonus.
        comps_path = os.path.join(os.path.dirname(__file__), "comps_used_in_sft.json")
        with open(comps_path, "r") as file:
            seen_comps = json.load(file)

        # Imported lazily so the package can be imported without pymatgen
        # installed (the __init__-level try/except registration relies on this
        # module importing cleanly).
        from pymatgen.core import Composition

        self.seen_comps_set = set()
        for comp in seen_comps:
            self.seen_comps_set.add(Composition(comp))

    def read_files(self) -> Dict:
        """Read element lists from the dataset JSON.

        Returns:
            Dict with parallel lists: "problem" holds element strings such as
            "O Te Tm"; "solution" is always "" (the task is open-ended, there
            is no single ground-truth compound).
        """
        dataset_path = expand_path(os.path.join(self.dataset_id_or_path, "NatureLM_conditional_v2.json"))
        with open(dataset_path, "r") as file:
            data = json.load(file)

        problems = []
        solutions = []

        for pt in data:
            try:
                # BUGFIX: the original used pt.get("elements"), which never
                # raises KeyError — entries missing the key silently appended
                # None.  Index directly so the except clause actually fires.
                problems.append(pt["elements"])
                solutions.append("")
            except KeyError as e:
                print(f"Missing expected key in data: {e}")

        return {
            "problem": problems,
            "solution": solutions,
        }

    def generate_prompt(self, problem, tokenizer, **kwargs):
        """Build the chat-templated prompt for one element list."""
        r1_prefix = [
            {"role": "system", "content": self.system_prompt},
            {
                "role": "user",
                "content": self.question_template.format(problem),
            },
        ]
        return {
            "prompt": tokenizer.apply_chat_template(r1_prefix, tokenize=False, continue_final_message=True),
            "problem": problem,
        }

    def load(self) -> DatasetDict:
        """Load and return the complete dataset (90/10 train/test split, seeded)."""
        train_dict = self.read_files()
        train_dataset = Dataset.from_dict(train_dict)
        seed = 42
        train_test_split = train_dataset.train_test_split(test_size=0.1, seed=seed)
        train_dataset = train_test_split["train"]
        test_dataset = train_test_split["test"]

        self.dataset = DatasetDict({"train": train_dataset, "test": test_dataset})
        return self.dataset

    def accuracy_reward(self, completions, solution, prompts, **kwargs):
        """Multi-component reward for each completion.

        Components: tag presence/ordering, termination on </answer>, reasoning
        length, element precision vs. the prompt, space-group validity, SMACT
        validity, and a novelty bonus.  `solution` is unused (open-ended task).
        """
        rewards = []

        for completion, prompt in zip(completions, prompts):
            reward = 0

            # --- Format: presence and ordering of the contract tags.
            think_start = completion.find("<think>")
            think_end = completion.find("</think>")
            answer_start = completion.find("<answer>")
            answer_end = completion.find("</answer>")

            if think_start != -1:
                reward += 0.25
            if think_end != -1:
                reward += 0.25
            if answer_start != -1:
                reward += 0.25
            if answer_end != -1:
                reward += 0.25

            if think_start != -1 and think_end != -1:
                if think_start < think_end:
                    reward += 0.25
            if answer_start != -1 and answer_end != -1:
                if answer_start < answer_end:
                    reward += 0.25
            if think_start != -1 and think_end != -1 and answer_start != -1 and answer_end != -1:
                if think_start < think_end and answer_start < answer_end and think_end < answer_start:
                    reward += 1
                else:
                    reward -= 1

            # Contract rule 6: nothing may follow </answer>.
            if completion.strip().endswith("</answer>"):
                reward += 1
            else:
                reward -= 2

            matches = re.findall(r"<think>(.*?)</think>", completion, flags=re.DOTALL)
            if matches:
                reasoning_len = len(matches[-1])
            else:
                reasoning_len = 0

            # Penalize degenerate (too short / missing) reasoning.
            if reasoning_len < 500:
                reward -= 5

            # --- Recover the requested elements from the rendered prompt.
            input_pattern = r"(?i)elements:\s*(.*?)\s*,\s*propose\b"
            match = re.search(input_pattern, prompt)
            # BUGFIX: the dataset/template uses space-separated elements
            # ("O Te Tm"); the original split(", ") returned one joined token,
            # so precision was always 0 and every output element counted as
            # "extra".  Split on whitespace instead.
            input_elements = match.group(1).split() if match else []

            # --- Extract elements and space group from the model output.
            # Two capture groups: element symbols and the <sg> digits.
            output_pattern = r"<answer>\s*((?:[A-Z][a-z]?\s*)+?)<sg>(\d+)</sg>\s*</answer>"
            output_matches = re.findall(output_pattern, completion)
            if len(output_matches) < 1:
                rewards.append(reward)
                continue
            elif len(output_matches) == 1:
                reward += 1
            else:
                reward -= 1
            output_elements_str, output_sg_str = output_matches[-1]
            output_sg = int(output_sg_str.strip())

            # Space groups are numbered 1..230.
            if not 1 <= output_sg <= 230:
                rewards.append(reward)
                continue
            reward += 1

            output_elements = output_elements_str.strip().split()

            # Penalize elements the prompt did not ask for.
            extra_elements = set(output_elements) - set(input_elements)
            reward -= len(extra_elements) * 0.5

            # Precision: fraction of requested elements that appear.
            intersection = set(input_elements) & set(output_elements)
            precision = len(intersection) / len(input_elements)
            if precision == 1:
                reward += 3
            else:
                reward += precision

            # Build a composition and run the SMACT charge-neutrality screen.
            try:
                from pymatgen.core import Composition
                from smact.screening import smact_validity

                comp = Composition(" ".join(output_elements))
                if not smact_validity(comp):
                    rewards.append(reward)
                    continue
            except Exception as e:
                # Best-effort: un-parseable symbol lists forfeit the remaining
                # bonuses rather than crashing the training loop.
                print(f"Invalid composition: {output_elements} -> {e}")
                rewards.append(reward)
                continue
            if precision == 1:
                reward += 3
            else:
                reward += 1

            # Novelty bonus: only the first occurrence of a composition pays.
            if comp not in self.seen_comps_set:
                reward += 2
                self.seen_comps_set.add(comp)

            self.random_log = {
                "prompt": prompt,
                "output_elements": output_elements,
                "output_sg": output_sg,
                "accuracy_reward": reward,
                "full_completion": completion,
            }
            self.good_print(self.random_log)

            rewards.append(reward)
        self.custom_metrics["val/rewards"].extend(rewards)
        return rewards

    def get_metrics(self) -> Dict:
        """
        Get task metrics to log in WANDB.
        This function takes no arguments and returns a dictionary of metrics {key[str]: value[float]}.
        """
        metrics = dict()
        if self.log_custom_metrics:
            rewards = self.custom_metrics["val/rewards"]
            if rewards:
                # NOTE(review): rewards are multi-component floats spanning a
                # wide range, so counting exactly r == 1 as "correct" is a weak
                # accuracy proxy — consider a threshold; behavior kept as-is.
                correct_count = sum(1 for r in rewards if r == 1)
                total_count = len(rewards)
                accuracy = correct_count / total_count if total_count > 0 else 0.0
                metrics["val/accuracy"] = accuracy
            self.custom_metrics["val/rewards"] = []
        return metrics