forked from InexperiencedMe/NaturalDreamer
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmain.py
More file actions
61 lines (49 loc) · 3.59 KB
/
main.py
File metadata and controls
61 lines (49 loc) · 3.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# Entry point for Dreamer training: environment setup, training loop, logging.
import gymnasium as gym
import torch
import argparse
import os
# Project-local modules: the Dreamer agent, config/seeding/plotting helpers,
# and the Gym wrappers that produce clean 64x64 pixel observations.
from dreamer import Dreamer
from utils import loadConfig, seedEverything, plotMetrics
from envs import getEnvProperties, GymPixelsProcessingWrapper, CleanGymWrapper
from utils import saveLossesToCSV, ensureParentFolders
from tqdm import tqdm
# Prefer GPU when available; passed into Dreamer so all tensors live on one device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(configFile):
    """Run the full Dreamer training loop described by a YAML config file.

    Loads the configuration, builds the training and evaluation environments,
    optionally resumes from a checkpoint, then alternates gradient updates with
    environment interaction — periodically saving checkpoints, evaluation
    videos, metrics CSV rows, and an updated metrics plot.
    """
    config = loadConfig(configFile)
    seedEverything(config.seed)

    # All artifacts (checkpoints, metrics, plots, videos) share one run name.
    runName = f"{config.environmentName}_{config.runName}"
    checkpointToLoad = os.path.join(config.folderNames.checkpointsFolder, f"{runName}_{config.checkpointToLoad}")
    metricsFilename = os.path.join(config.folderNames.metricsFolder, runName)
    plotFilename = os.path.join(config.folderNames.plotsFolder, runName)
    checkpointFilenameBase = os.path.join(config.folderNames.checkpointsFolder, runName)
    videoFilenameBase = os.path.join(config.folderNames.videosFolder, runName)
    ensureParentFolders(metricsFilename, plotFilename, checkpointFilenameBase, videoFilenameBase)

    def buildEnv(**makeKwargs):
        # 64x64 pixel observations — presumably the world model's expected input size.
        rawEnv = gym.make(config.environmentName, **makeKwargs)
        return CleanGymWrapper(GymPixelsProcessingWrapper(gym.wrappers.ResizeObservation(rawEnv, (64, 64))))

    env = buildEnv()
    # Evaluation env renders rgb_array frames so episodes can be saved as video.
    envEvaluation = buildEnv(render_mode="rgb_array")

    observationShape, actionSize, actionLow, actionHigh = getEnvProperties(env)
    print(f"envProperties: obs {observationShape}, action size {actionSize}, actionLow {actionLow}, actionHigh {actionHigh}")

    dreamer = Dreamer(observationShape, actionSize, actionLow, actionHigh, device, config.dreamer)
    if config.resume:
        dreamer.loadCheckpoint(checkpointToLoad)

    # Prefill the replay buffer before any gradient step.
    dreamer.environmentInteraction(env, config.episodesBeforeStart, seed=config.seed)

    # Each outer iteration: `replayRatio` gradient updates, then one round of
    # environment interaction and metric logging.
    iterationsNum = config.gradientSteps // config.replayRatio
    for _ in tqdm(range(iterationsNum)):
        for _ in range(config.replayRatio):
            sampledData = dreamer.buffer.sample(dreamer.config.batchSize, dreamer.config.batchLength)
            initialStates, worldModelMetrics = dreamer.worldModelTraining(sampledData)
            behaviorMetrics = dreamer.behaviorTraining(initialStates)
            dreamer.totalGradientSteps += 1

            if config.saveCheckpoints and dreamer.totalGradientSteps % config.checkpointInterval == 0:
                suffix = f"{dreamer.totalGradientSteps/1000:.0f}k"
                dreamer.saveCheckpoint(f"{checkpointFilenameBase}_{suffix}")
                evaluationScore = dreamer.environmentInteraction(envEvaluation, config.numEvaluationEpisodes, seed=config.seed, evaluation=True, saveVideo=True, filename=f"{videoFilenameBase}_{suffix}")
                print(f"Saved Checkpoint and Video at {suffix:>6} gradient steps. Evaluation score: {evaluationScore:>8.2f}")

        mostRecentScore = dreamer.environmentInteraction(env, config.numInteractionEpisodes, seed=config.seed)
        if config.saveMetrics:
            # One CSV row per outer iteration: counters + the latest loss dicts.
            metricsBase = {"envSteps": dreamer.totalEnvSteps, "gradientSteps": dreamer.totalGradientSteps, "totalReward" : mostRecentScore}
            saveLossesToCSV(metricsFilename, metricsBase | worldModelMetrics | behaviorMetrics)
            plotMetrics(f"{metricsFilename}", savePath=f"{plotFilename}", title=f"{config.environmentName}")
if __name__ == "__main__":
    # CLI entry point: only knob is the config file path.
    argumentParser = argparse.ArgumentParser()
    argumentParser.add_argument("--config", type=str, default="car-racing-v3.yml")
    arguments = argumentParser.parse_args()
    main(arguments.config)