-
Notifications
You must be signed in to change notification settings - Fork 49
Expand file tree
/
Copy pathalfworld_eval.py
More file actions
56 lines (42 loc) · 1.34 KB
/
alfworld_eval.py
File metadata and controls
56 lines (42 loc) · 1.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import os
import sys
import subprocess
from termcolor import cprint
from omegaconf import DictConfig, ListConfig, OmegaConf
def get_config():
    """Build the run configuration from the command line and a YAML file.

    Command-line arguments are parsed with ``OmegaConf.from_cli()``. The
    ``config`` entry among them is used as the path of the YAML file to load.
    The YAML configuration and the command-line configuration are then merged,
    with the command-line configuration passed last to ``OmegaConf.merge``.

    Returns:
        The merged OmegaConf configuration object.

    Raises:
        ValueError: If no ``config=<path>`` argument was supplied.
    """
    cli_conf = OmegaConf.from_cli()

    # Look the key up explicitly so that a forgotten ``config=...`` argument
    # produces a clear message instead of an opaque attribute/type error.
    config_path = cli_conf.get("config")
    if config_path is None:
        raise ValueError(
            "Missing required command-line argument: config=<path/to/file.yaml>"
        )

    yaml_conf = OmegaConf.load(config_path)
    # The CLI config is merged last so its values take priority over the file.
    return OmegaConf.merge(yaml_conf, cli_conf)
if __name__ == "__main__":
    # Script entry point: load the merged configuration, then run the sampling
    # stage followed by the reward stage, each as a child Python process.
    config = get_config()
    project_name = config.experiment.project
    env_type = config.dataset.environment_type

    # For every supported environment type: the script launched per stage.
    # The stage name doubles as the directory the script is run from.
    stage_scripts = {
        "alfworld": {
            "sample": "alfworld_sample.py",
            "reward": "alfworld_reward.py",
        },
    }

    # Fail early with a clear message; previously an unsupported value only
    # surfaced later as an UnboundLocalError on ``script_name``.
    if env_type not in stage_scripts:
        raise ValueError(
            f"Unsupported dataset.environment_type {env_type!r}; "
            f"expected one of {sorted(stage_scripts)}"
        )

    def begin_with(file_name):
        """Create ``file_name`` as an empty file, truncating it if it exists.

        This helper is not called anywhere in this script.
        """
        with open(file_name, "w") as f:
            f.write("")

    def run_stage(stage):
        """Run the script registered for ``stage`` and wait for it to finish.

        The child process is started with ``stage`` as its working directory
        and receives the project's YAML path as a ``config=...`` argument.

        Raises:
            subprocess.CalledProcessError: If the script exits with a
                non-zero status (``check=True``).
        """
        script_name = stage_scripts[env_type][stage]
        subprocess.run(
            # An argument list (no shell) keeps ``project_name`` from being
            # interpreted by a shell; ``sys.executable`` runs the child with
            # the same interpreter as this script.
            [
                sys.executable,
                script_name,
                f"config=../configs/{project_name}.yaml",
            ],
            cwd=stage,
            check=True,
        )

    def sample():
        """Announce and run the sampling stage."""
        cprint("This is sampling.", color="green")
        run_stage("sample")

    def reward():
        """Announce and run the reward stage."""
        cprint("This is the rewarding.", color="green")
        run_stage("reward")

    os.makedirs(f"{project_name}/results", exist_ok=True)
    sample()
    reward()