@@ -47,7 +47,13 @@ def multiclass_dataset():
4747
4848def test_gcn_scorer_multilabel (multilabel_dataset ):
4949 torch .manual_seed (42 )
50- scorer = GCNScorer (embedder_config = _embedder_config , label_embedder_config = _embedder_config , num_train_epochs = 1 , batch_size = 2 , seed = 42 )
50+ scorer = GCNScorer (
51+ embedder_config = _embedder_config ,
52+ label_embedder_config = _embedder_config ,
53+ num_train_epochs = 1 ,
54+ batch_size = 2 ,
55+ seed = 42 ,
56+ )
5157 train_utterances = multilabel_dataset ["train" ]["utterance" ]
5258 train_labels = multilabel_dataset ["train" ]["label" ]
5359 descriptions = [intent .name for intent in multilabel_dataset .intents ]
@@ -62,7 +68,13 @@ def test_gcn_scorer_multilabel(multilabel_dataset):
6268
6369def test_gcn_scorer_multiclass (multiclass_dataset ):
6470 torch .manual_seed (42 )
65- scorer = GCNScorer (embedder_config = _embedder_config , label_embedder_config = _embedder_config , num_train_epochs = 1 , batch_size = 2 , seed = 42 )
71+ scorer = GCNScorer (
72+ embedder_config = _embedder_config ,
73+ label_embedder_config = _embedder_config ,
74+ num_train_epochs = 1 ,
75+ batch_size = 2 ,
76+ seed = 42 ,
77+ )
6678 train_utterances = multiclass_dataset ["train" ]["utterance" ]
6779 train_labels = multiclass_dataset ["train" ]["label" ]
6880 descriptions = [intent .name for intent in multiclass_dataset .intents ]
@@ -78,7 +90,13 @@ def test_gcn_scorer_multiclass(multiclass_dataset):
7890
7991def test_gcn_scorer_dump_load (tmp_path , multilabel_dataset ):
8092 torch .manual_seed (42 )
81- scorer = GCNScorer (embedder_config = _embedder_config , label_embedder_config = _embedder_config , num_train_epochs = 1 , batch_size = 2 , seed = 42 )
93+ scorer = GCNScorer (
94+ embedder_config = _embedder_config ,
95+ label_embedder_config = _embedder_config ,
96+ num_train_epochs = 1 ,
97+ batch_size = 2 ,
98+ seed = 42 ,
99+ )
82100 train_utterances = multilabel_dataset ["train" ]["utterance" ]
83101 train_labels = multilabel_dataset ["train" ]["label" ]
84102 descriptions = [intent .name for intent in multilabel_dataset .intents ]
0 commit comments