@@ -66,6 +66,7 @@ def __init__(
6666
6767 def convert_task (self , task : TrinityTask ):
6868 from ajet .schema .task import Task
69+
6970 assert isinstance (task .raw_task , dict )
7071 return dict_to_ajet_task (task .raw_task )
7172
@@ -150,16 +151,10 @@ async def run_async(self):
150151 "madness" : tracker .reward_structure .madness ,
151152 }
152153
153- if (
154- len (response_ids ) + len (prompt_ids ) == len (input_ids )
155- and len (logprobs ) == len (response_ids )
156- and len (logprobs ) > 0
157- ):
154+ if len (response_ids ) + len (prompt_ids ) == len (input_ids ) and len (logprobs ) == len (response_ids ) and len (logprobs ) > 0 :
158155 exp = Experience (
159156 tokens = input_ids , # [seq_length] prompt + response
160- prompt_length = len (
161- prompt_ids
162- ), # Length of the prompt in tokens, used for generating attention masks
157+ prompt_length = len (prompt_ids ), # Length of the prompt in tokens, used for generating attention masks
163158 logprobs = logprobs , # [resp_length]
164159 reward = reward , #
165160 # advantages=None,
@@ -211,19 +206,11 @@ def __init__(self, config):
211206 if "train" in self .split :
212207 dataset_segments .append (task_to_standard_dataset (task_reader .get_training_tasks ()))
213208 if "val" in self .split :
214- dataset_segments .append (
215- task_to_standard_dataset (task_reader .get_validation_tasks ())
216- )
209+ dataset_segments .append (task_to_standard_dataset (task_reader .get_validation_tasks ()))
217210 if not dataset_segments :
218- raise ValueError (
219- f"Unsupported split '{ self .split } '. Expected to contain 'train' or 'val'."
220- )
211+ raise ValueError (f"Unsupported split '{ self .split } '. Expected to contain 'train' or 'val'." )
221212
222- concatenated_dataset = (
223- dataset_segments [0 ]
224- if len (dataset_segments ) == 1
225- else datasets .concatenate_datasets (dataset_segments )
226- )
213+ concatenated_dataset = dataset_segments [0 ] if len (dataset_segments ) == 1 else datasets .concatenate_datasets (dataset_segments )
227214
228215 self .dataset = _HFBatchReader (
229216 concatenated_dataset ,
@@ -271,15 +258,9 @@ class SwanlabMonitor(Monitor):
271258 """
272259
273260 def __init__ (self , project : str , group : str , name : str , role : str , config ) -> None :
274- assert (
275- swanlab is not None
276- ), "swanlab is not installed. Please install it to use SwanlabMonitor."
277-
278- monitor_args = (
279- (config .monitor .monitor_args or {})
280- if config and getattr (config , "monitor" , None )
281- else {}
282- )
261+ assert swanlab is not None , "swanlab is not installed. Please install it to use SwanlabMonitor."
262+
263+ monitor_args = (config .monitor .monitor_args or {}) if config and getattr (config , "monitor" , None ) else {}
283264
284265 # Optional API login via code if provided; otherwise try environment, then rely on prior `swanlab login`.
285266 api_key = os .environ .get ("SWANLAB_API_KEY" )
@@ -331,9 +312,7 @@ def __init__(self, project: str, group: str, name: str, role: str, config) -> No
331312 self .data_dashboard_url = run_info ["cloud" ]["experiment_url" ]
332313
333314 def log_table (self , table_name : str , experiences_table , step : int ):
334- assert (
335- swanlab is not None
336- ), "swanlab is not installed. Please install it to use SwanlabMonitor."
315+ assert swanlab is not None , "swanlab is not installed. Please install it to use SwanlabMonitor."
337316
338317 # Convert pandas DataFrame to SwanLab ECharts Table
339318 headers : List [str ] = list (experiences_table .columns )
@@ -351,9 +330,7 @@ def log_table(self, table_name: str, experiences_table, step: int):
351330 def log (self , data : dict , step : int , commit : bool = False ) -> None :
352331 """Log metrics."""
353332 # SwanLab doesn't use commit flag; keep signature for compatibility
354- assert (
355- swanlab is not None
356- ), "swanlab is not installed. Please install it to use SwanlabMonitor."
333+ assert swanlab is not None , "swanlab is not installed. Please install it to use SwanlabMonitor."
357334 swanlab .log (data , step = step )
358335 self .console_logger .info (f"Step { step } : { data } " )
359336
@@ -372,9 +349,7 @@ def log(self, data: dict, step: int, commit: bool = False) -> None:
372349 test_robot_data = {}
373350 test_robot_data ["step" ] = step
374351 test_robot_data ["data_dashboard_url" ] = self .data_dashboard_url
375- test_robot_data ["reward_for_test_robot" ] = data [
376- "experience_pipeline/group_advantages/reward_mean/mean"
377- ]
352+ test_robot_data ["reward_for_test_robot" ] = data ["experience_pipeline/group_advantages/reward_mean/mean" ]
378353 _test_if_test_mode (key = "reward_probe" , value = test_robot_data , config = ajet_config )
379354
380355 def close (self ) -> None :
0 commit comments