diff --git a/demo/condmatgen_tiny/NatureLM_conditional_v2.json b/demo/condmatgen_tiny/NatureLM_conditional_v2.json
new file mode 100644
index 00000000..767c3769
--- /dev/null
+++ b/demo/condmatgen_tiny/NatureLM_conditional_v2.json
@@ -0,0 +1,52 @@
+[
+ {"elements": "O Te Tm"},
+ {"elements": "Si Al O"},
+ {"elements": "Fe Ni"},
+ {"elements": "Li Co O"},
+ {"elements": "Na Cl"},
+ {"elements": "Ca Ti O"},
+ {"elements": "Mg O"},
+ {"elements": "Zn S"},
+ {"elements": "Cu Fe S"},
+ {"elements": "Ba Ti O"},
+ {"elements": "Sr Ti O"},
+ {"elements": "Ga As"},
+ {"elements": "In P"},
+ {"elements": "Al N"},
+ {"elements": "Si C"},
+ {"elements": "Fe O"},
+ {"elements": "Mn O"},
+ {"elements": "Cr O"},
+ {"elements": "V O"},
+ {"elements": "Ti O"},
+ {"elements": "Li Mn O"},
+ {"elements": "Li Fe P O"},
+ {"elements": "Na Fe O"},
+ {"elements": "K Al Si O"},
+ {"elements": "Ca Mg Si O"},
+ {"elements": "Zr O"},
+ {"elements": "Hf O"},
+ {"elements": "Nb O"},
+ {"elements": "Mo S"},
+ {"elements": "W Se"},
+ {"elements": "Bi Te"},
+ {"elements": "Sb Te"},
+ {"elements": "Pb Te"},
+ {"elements": "Sn Se"},
+ {"elements": "Ge Te"},
+ {"elements": "Cd Se"},
+ {"elements": "Zn Te"},
+ {"elements": "Cu Zn Sn S"},
+ {"elements": "Ag Bi Se"},
+ {"elements": "La Mn O"},
+ {"elements": "Y Ba Cu O"},
+ {"elements": "Nd Fe B"},
+ {"elements": "Sm Co"},
+ {"elements": "Ce O"},
+ {"elements": "Pr Ni O"},
+ {"elements": "Gd Fe O"},
+ {"elements": "Eu Ti O"},
+ {"elements": "Tb Mn O"},
+ {"elements": "Dy Fe O"},
+ {"elements": "Ho Co O"}
+]
diff --git a/demo/fixture_manifest.csv b/demo/fixture_manifest.csv
index 1996d9a7..1c835e69 100644
--- a/demo/fixture_manifest.csv
+++ b/demo/fixture_manifest.csv
@@ -12,3 +12,4 @@ rxn_inversion,demo/datasets/rxn_inversion_sample.csv,50,"~45 train / ~5 test",re
rxn_replacement,demo/datasets/rxn_inversion_sample.csv,50,"~45 train / ~5 test",repo-local fixture,ready,"Same schema as rxn_inversion (MCQ with 4 options)"
rxn_naming,demo/datasets/rxn_naming_sample.csv,50,"~45 train / ~5 test",repo-local fixture,ready,"Reaction classification into 10 named categories"
rxn_truefalse,demo/datasets/rxn_truefalse_sample.csv,50,"~45 train / ~5 test",repo-local fixture,ready,"Binary true/false reaction validity"
+condmatgen,demo/condmatgen_tiny,50,"~45 train / ~5 test",repo-local fixture,ready,"JSON with element lists for conditional material generation"
diff --git a/demo/run_fixture_smoke.py b/demo/run_fixture_smoke.py
index 4316f664..937de084 100644
--- a/demo/run_fixture_smoke.py
+++ b/demo/run_fixture_smoke.py
@@ -43,6 +43,9 @@ def main():
"rxn_truefalse": datasets_dir / "rxn_truefalse_sample.csv",
}
+ if "condmatgen" in CHEMTASKS:
+ task_configs["condmatgen"] = demo_dir / "condmatgen_tiny"
+
summary = {}
for task_name, dataset_path in task_configs.items():
task_class = CHEMTASKS[task_name]
diff --git a/docs/source/modules.rst b/docs/source/modules.rst
index e9a2c264..a930a00c 100644
--- a/docs/source/modules.rst
+++ b/docs/source/modules.rst
@@ -17,4 +17,5 @@ Modules Reference
tasks/rxn_naming
tasks/rxn_truefalse
tasks/template
- tasks/smi_permute
\ No newline at end of file
+ tasks/smi_permute
+ tasks/condmatgen
\ No newline at end of file
diff --git a/docs/source/tasks/condmatgen.rst b/docs/source/tasks/condmatgen.rst
new file mode 100644
index 00000000..997b5532
--- /dev/null
+++ b/docs/source/tasks/condmatgen.rst
@@ -0,0 +1,26 @@
+Conditional Material Generation (CMG)
+======================================
+
+.. currentmodule:: open_r1.tasks.condmatgen.condmatgen
+
+ConditionalMaterialGeneration
+-----------------------------
+
+.. autoclass:: ConditionalMaterialGeneration
+ :members:
+ :show-inheritance:
+
+Task Description
+----------------
+
+Given a set of chemical elements, the model proposes a novel crystalline
+compound (element list and space group number). The model wraps its reasoning
+in ``<think>...</think>`` tags and its answer in ``<answer>...</answer>`` tags.
+
+Reward Functions
+----------------
+
+- **accuracy**: multi-component scoring including SMACT validity, element
+ precision, space group validity, and novelty bonus.
+- **format**: checks presence and ordering of think/answer tags, penalizes
+ short reasoning.
diff --git a/recipes/condmatgen.yaml b/recipes/condmatgen.yaml
new file mode 100644
index 00000000..028afba3
--- /dev/null
+++ b/recipes/condmatgen.yaml
@@ -0,0 +1,49 @@
+# Model arguments
+model_revision: main
+torch_dtype: bfloat16
+attn_implementation: flash_attention_2
+bf16: true
+tf32: true
+
+# Chemical Task arguments
+chem_task: condmatgen
+dataset_id_or_path: ${MIST_DATA_DIR}/condmatgen
+rewards:
+- accuracy
+
+# Lora Arguments
+# No LoRA is used here
+
+# Training arguments
+max_steps: 1450
+per_device_train_batch_size: 1
+gradient_accumulation_steps: 8
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+ use_reentrant: false
+learning_rate: 2.0e-6 # 1.0e-6 as in the DeepSeek Math paper; 5e-7 from https://hijkzzz.notion.site/unraveling-rlhf-and-its-variants-engineering-insights#147d9a33ecc9806090f3d5c749d31f05
+lr_scheduler_type: cosine
+warmup_ratio: 0.03
+# GRPO specific parameters
+beta: 0.04 # 0.04 as in the DeepSeek Math paper; 0.001 from https://hijkzzz.notion.site/unraveling-rlhf-and-its-variants-engineering-insights#147d9a33ecc9806090f3d5c749d31f05
+max_prompt_length: 600
+max_completion_length: 2048
+num_generations: 4
+use_vllm: true
+vllm_device: "cuda:3"
+vllm_gpu_memory_utilization: 0.8
+vllm_max_model_len: 2048
+
+# Logging arguments
+logging_strategy: steps
+logging_steps: 1
+report_to:
+- wandb
+
+save_strategy: "steps"
+save_steps: 25
+seed: 42
+
+# Hugging Face Hub
+push_to_hub: false
+ # hub_model_id: llama-3-1-8b-math-orca-qlora-10k-ep1 # if not defined same as output_dir
\ No newline at end of file
diff --git a/src/open_r1/tasks/__init__.py b/src/open_r1/tasks/__init__.py
index 8812ce63..b477254c 100644
--- a/src/open_r1/tasks/__init__.py
+++ b/src/open_r1/tasks/__init__.py
@@ -10,6 +10,11 @@
from .reactions.smi_permute import PermuteSmiles
from .smiles_understanding.smiles_hydrogen import SmilesHydrogen
+try:
+ from .condmatgen.condmatgen import ConditionalMaterialGeneration
+except ImportError:
+ ConditionalMaterialGeneration = None
+
# Task keys as specified in the task recipes and documentation
CHEMTASKS = {
"rxnpred_with_tags": ForwardReactionWithTags,
@@ -26,3 +31,6 @@
"rxn_naming": Smiles2Name,
"rxn_truefalse": ReactionTrueFalse,
}
+
+if ConditionalMaterialGeneration is not None:
+ CHEMTASKS["condmatgen"] = ConditionalMaterialGeneration
diff --git a/src/open_r1/tasks/condmatgen/__init__.py b/src/open_r1/tasks/condmatgen/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/open_r1/tasks/condmatgen/comps_used_in_sft.json b/src/open_r1/tasks/condmatgen/comps_used_in_sft.json
new file mode 100644
index 00000000..0637a088
--- /dev/null
+++ b/src/open_r1/tasks/condmatgen/comps_used_in_sft.json
@@ -0,0 +1 @@
+[]
\ No newline at end of file
diff --git a/src/open_r1/tasks/condmatgen/condmatgen.py b/src/open_r1/tasks/condmatgen/condmatgen.py
new file mode 100644
index 00000000..6862cd17
--- /dev/null
+++ b/src/open_r1/tasks/condmatgen/condmatgen.py
@@ -0,0 +1,249 @@
+import json
+import os
+import re
+from dataclasses import field
+from typing import Any, Dict
+
+from datasets import Dataset, DatasetDict
+
+from open_r1.paths import expand_path
+
+from ..base import RLTask
+
+
class ConditionalMaterialGeneration(RLTask):
    """Conditional material generation (CMG) RL task.

    Given a set of chemical elements, the model must propose one novel
    crystalline compound: a space-separated element list followed by a
    space-group tag, wrapped in ``<think>...</think>`` /
    ``<answer>...</answer>`` blocks.

    NOTE(review): the literal ``<think>``/``<answer>``/``<sg N>`` tags were
    stripped from the original source (empty ``find("")`` calls and regex
    anchors); they are reconstructed here — confirm the exact tag spelling
    against the SFT data / trained checkpoints.
    """

    # Class-level defaults. All mutable state (custom_metrics, seen_comps_set,
    # random_log) is created per-instance in __init__: this class is not a
    # dataclass, so dataclasses.field() would only leave Field sentinel
    # objects here, and a class-level `{}` would be shared across instances.
    question_template: str = ""
    log_custom_metrics: bool = True
    custom_metrics: Dict[str, Any]
    seen_comps_set: set
    random_log: Dict[str, Any]

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.system_prompt = (
            "You are a careful model that must follow the Output Contract exactly.\n\n"
            "OUTPUT CONTRACT\n"
            "1) You may think only inside <think>...</think>.\n"
            "2) Your final answer must be a single line wrapped in <answer>...</answer>.\n"
            "3) Inside <answer>, list only element symbols separated by single spaces, "
            "followed by a space-group tag <sgN>.\n"
            "4) If you generate a chemical formula with subscripts (e.g. Ni\u2082Fe\u2084LiO\u2081\u2080), "
            "you must expand them into repeated symbols "
            "(Ni Ni Fe Fe Fe Fe Li O O O O O O O O O O).\n"
            "5) Do not include any extra words, punctuation, examples, explanations, "
            "or text outside the tags.\n"
            "6) After you produce </answer>, you must stop. No tokens are allowed after </answer>.\n\n"
            "VALID EXAMPLE (format only):\n"
            "<think>reasoning</think>\n"
            "<answer> Ca O Sn Sn <sg62> </answer>\n\n"
            "INVALID EXAMPLES:\n"
            "- Missing tags\n"
            "- Commas, bullets, or explanations inside <answer>\n"
            "- Multiple <answer> blocks\n"
            "- Extra output after </answer>\n\n"
            "Allowed tokens inside <answer>: element symbols (H He Li ... Og), "
            "single spaces, and the literal pattern <sgN>."
        )

        self.question_template = (
            "You are a materials science expert.\n"
            "Given the following elements: {}, propose one chemically valid and novel crystalline compound."
        )

        self.log_custom_metrics = True
        self.custom_metrics = {
            "val/rewards": [],
        }
        self.random_log = {}

        # Compositions already seen during SFT; used for the novelty bonus.
        comps_path = os.path.join(os.path.dirname(__file__), "comps_used_in_sft.json")
        with open(comps_path, "r") as file:
            seen_comps = json.load(file)

        # Imported lazily so the module stays importable without pymatgen.
        from pymatgen.core import Composition

        self.seen_comps_set = set()
        for comp in seen_comps:
            self.seen_comps_set.add(Composition(comp))

    def read_files(self) -> Dict:
        """Read the JSON fixture and return parallel problem/solution lists.

        Each record must carry an ``"elements"`` key (space-separated element
        symbols); records missing it are skipped with a warning. This task is
        generative, so the solution column is an empty string.
        """
        dataset_path = expand_path(os.path.join(self.dataset_id_or_path, "NatureLM_conditional_v2.json"))
        with open(dataset_path, "r") as file:
            data = json.load(file)

        problems = []
        solutions = []

        for pt in data:
            try:
                # FIX: use pt["elements"], not pt.get("elements") — .get can
                # never raise KeyError, so the except branch was dead code and
                # a missing key silently appended None.
                problems.append(pt["elements"])
                solutions.append("")
            except KeyError as e:
                print(f"Missing expected key in data: {e}")

        return {
            "problem": problems,
            "solution": solutions,
        }

    def generate_prompt(self, problem, tokenizer, **kwargs):
        """Render the chat-templated prompt for one element-list problem."""
        r1_prefix = [
            {"role": "system", "content": self.system_prompt},
            {
                "role": "user",
                "content": self.question_template.format(problem),
            },
        ]
        return {
            "prompt": tokenizer.apply_chat_template(r1_prefix, tokenize=False, continue_final_message=True),
            "problem": problem,
        }

    def load(self) -> DatasetDict:
        """Load and return the complete dataset (90/10 train/test split)."""
        train_dict = self.read_files()
        train_dataset = Dataset.from_dict(train_dict)
        seed = 42  # fixed seed for a reproducible split
        train_test_split = train_dataset.train_test_split(test_size=0.1, seed=seed)
        train_dataset = train_test_split["train"]
        test_dataset = train_test_split["test"]

        self.dataset = DatasetDict({"train": train_dataset, "test": test_dataset})
        return self.dataset

    def accuracy_reward(self, completions, solution, prompts, **kwargs):
        """Score each completion: tag format, element precision, space-group
        validity, SMACT chemical validity, and a novelty bonus.

        Returns a list of float rewards, one per completion.
        """
        rewards = []

        for completion, prompt in zip(completions, prompts):
            reward = 0

            # --- Format: presence and ordering of think/answer tags -------
            # NOTE(review): tag literals reconstructed; see class docstring.
            think_start = completion.find("<think>")
            think_end = completion.find("</think>")
            answer_start = completion.find("<answer>")
            answer_end = completion.find("</answer>")

            if think_start != -1:
                reward += 0.25
            if think_end != -1:
                reward += 0.25
            if answer_start != -1:
                reward += 0.25
            if answer_end != -1:
                reward += 0.25

            if think_start != -1 and think_end != -1 and think_start < think_end:
                reward += 0.25
            if answer_start != -1 and answer_end != -1 and answer_start < answer_end:
                reward += 0.25
            if -1 not in (think_start, think_end, answer_start, answer_end):
                # Full bonus only for the canonical <think>...</think><answer>...</answer> order.
                if think_start < think_end < answer_start < answer_end:
                    reward += 1
                else:
                    reward -= 1

            # Completion must terminate right after the closing answer tag.
            if completion.strip().endswith("</answer>"):
                reward += 1
            else:
                reward -= 2

            # Penalize very short reasoning (last think block under 500 chars).
            think_blocks = re.findall(r"<think>(.*?)</think>", completion, flags=re.DOTALL)
            reasoning_len = len(think_blocks[-1]) if think_blocks else 0
            if reasoning_len < 500:
                reward -= 5

            # --- Requested elements, parsed back out of the prompt --------
            input_pattern = r"(?i)elements:\s*(.*?)\s*,\s*propose\b"
            match = re.search(input_pattern, prompt)
            # FIX: the prompt lists elements space-separated ("O Te Tm"), not
            # comma-separated — splitting on ", " yielded one bogus token and
            # broke the precision/extra-element scoring below.
            input_elements = match.group(1).split() if match else []

            # --- Proposed elements and space group from the answer --------
            output_pattern = r"<answer>\s*((?:[A-Z][a-z]?\s*)+?)<sg(\d{1,3})>\s*</answer>"
            output_matches = re.findall(output_pattern, completion)
            if len(output_matches) < 1:
                rewards.append(reward)
                continue
            elif len(output_matches) == 1:
                reward += 1
            else:
                reward -= 1
            elements_str, sg_str = output_matches[-1]
            output_sg = int(sg_str.strip())

            # Crystallographic space-group numbers are 1..230.
            if not 1 <= output_sg <= 230:
                rewards.append(reward)
                continue
            reward += 1

            output_elements = elements_str.strip().split()

            # Penalize elements not in the requested set.
            extra_elements = set(output_elements) - set(input_elements)
            reward -= len(extra_elements) * 0.5

            # Precision against the requested set. FIX: guard the empty case,
            # which previously raised ZeroDivisionError.
            intersection = set(input_elements) & set(output_elements)
            precision = len(intersection) / len(input_elements) if input_elements else 0.0
            if precision == 1:
                reward += 3
            else:
                reward += precision

            # --- Chemical validity via SMACT (after the penalties above) --
            try:
                from pymatgen.core import Composition
                from smact.screening import smact_validity

                comp = Composition(" ".join(output_elements))
                if not smact_validity(comp):
                    rewards.append(reward)
                    continue
            except Exception as e:
                print(f"Invalid composition: {output_elements} -> {e}")
                rewards.append(reward)
                continue
            # Valid-composition bonus, larger when the element set matches exactly.
            if precision == 1:
                reward += 3
            else:
                reward += 1

            # Novelty bonus: reward each unseen composition once.
            if comp not in self.seen_comps_set:
                reward += 2
                self.seen_comps_set.add(comp)

            self.random_log = {
                "prompt": prompt,
                "output_elements": output_elements,
                "output_sg": output_sg,
                "accuracy_reward": reward,
                "full_completion": completion,
            }
            self.good_print(self.random_log)

            rewards.append(reward)
        self.custom_metrics["val/rewards"].extend(rewards)
        return rewards

    def get_metrics(self) -> Dict:
        """
        Get task metrics to log in WANDB.
        This function takes no arguments and returns a dictionary of metrics {key[str]: value[float]}.
        Clears the accumulated reward buffer after reading it.
        """
        metrics = dict()
        if self.log_custom_metrics:
            rewards = self.custom_metrics["val/rewards"]
            if rewards:
                # NOTE(review): "accuracy" counts rewards exactly equal to 1,
                # yet accuracy_reward produces a wide range of values —
                # confirm the intended success threshold.
                correct_count = sum(1 for r in rewards if r == 1)
                total_count = len(rewards)
                accuracy = correct_count / total_count if total_count > 0 else 0.0
                metrics["val/accuracy"] = accuracy
            self.custom_metrics["val/rewards"] = []
        return metrics