@@ -41,7 +41,7 @@ class Base:
4141
4242 """
4343
44- def __init__ (self , p , q , workers , verbose = False , extend = False ):
44+ def __init__ (self , p , q , workers , verbose = False , extend = False , gamma = 0 ):
4545 """Initializ node2vec base class.
4646
4747 Args:
@@ -53,7 +53,11 @@ def __init__(self, p, q, workers, verbose=False, extend=False):
5353 workers (int): number of threads to be spawned for runing node2vec
5454 including walk generation and word2vec embedding.
5555 verbose (bool): show progress bar for walk generation.
56- extend (bool): ``True`` if use node2vec+ extension, default is ``False``
56+ extend (bool): use node2vec+ extension if set to :obj:`True`
57+ (default: :obj:`False`).
58+ gamma (float): Multiplication factor for the std term of edge
59+ weights added to the average edge weights as the noisy edge
60+ threashold, only used by node2vec+ (default: 0)
5761
5862 """
5963 super ().__init__ ()
@@ -62,6 +66,7 @@ def __init__(self, p, q, workers, verbose=False, extend=False):
6266 self .workers = workers
6367 self .verbose = verbose
6468 self .extend = extend
69+ self .gamma = gamma
6570
6671 def _map_walk (self , walk_idx_ary ):
6772 """Map walk from node index to node ID.
@@ -148,16 +153,16 @@ def setup_get_normalized_probs(self):
148153 probability computation function ``get_extended_normalized_probs``,
149154 if node2vec+ is used. Otherwise, return the normal transition function
150155 ``get_noramlized_probs`` with a trivial placeholder for average edge
151- weights array ``avg_wts ``.
156+ weights array ``noise_thresholds ``.
152157
153158 """
154159 if self .extend : # use n2v+
155160 get_normalized_probs = self .get_extended_normalized_probs
156- avg_wts = self .get_average_weights ()
161+ noise_thresholds = self .get_noise_thresholds ()
157162 else : # use normal n2v
158163 get_normalized_probs = self .get_normalized_probs
159- avg_wts = None
160- return get_normalized_probs , avg_wts
164+ noise_thresholds = None
165+ return get_normalized_probs , noise_thresholds
161166
162167 def preprocess_transition_probs (self ):
163168 """Null default preprocess method."""
@@ -221,9 +226,9 @@ def embed(
221226class FirstOrderUnweighted (Base , SparseRWGraph ):
222227 """Directly sample edges for first order random walks."""
223228
224- def __init__ (self , p , q , workers , verbose = False , extend = False ):
229+ def __init__ (self , * args , ** kwargs ):
225230 """Initialize FirstOrderUnweighted mode."""
226- Base .__init__ (self , p , q , workers , verbose , extend )
231+ Base .__init__ (self , * args , ** kwargs )
227232
228233 def get_move_forward (self ):
229234 """Wrap ``move_forward``."""
@@ -241,9 +246,9 @@ def move_forward(cur_idx, prev_idx=None):
241246class PreCompFirstOrder (Base , SparseRWGraph ):
242247 """Precompute transition probabilities for first order random walks."""
243248
244- def __init__ (self , p , q , workers , verbose = False , extend = False ):
249+ def __init__ (self , * args , ** kwargs ):
245250 """Initialize PreCompFirstOrder mode."""
246- Base .__init__ (self , p , q , workers , verbose , extend )
251+ Base .__init__ (self , * args , ** kwargs )
247252 self .alias_j = self .alias_q = None
248253
249254 def get_move_forward (self ):
@@ -304,9 +309,9 @@ class PreComp(Base, SparseRWGraph):
304309
305310 """
306311
307- def __init__ (self , p , q , workers , verbose = False , extend = False ):
312+ def __init__ (self , * args , ** kwargs ):
308313 """Initialize PreComp mode node2vec."""
309- Base .__init__ (self , p , q , workers , verbose , extend )
314+ Base .__init__ (self , * args , ** kwargs )
310315 self .alias_j = self .alias_q = self .alias_indptr = self .alias_dim = None
311316
312317 def get_move_forward (self ):
@@ -390,7 +395,7 @@ def preprocess_transition_probs(self):
390395 q = self .q
391396
392397 # Retrieve transition probability computation callback function
393- get_normalized_probs , avg_wts = self .setup_get_normalized_probs ()
398+ get_normalized_probs , noise_thresholds = self .setup_get_normalized_probs ()
394399
395400 # Determine the dimensionality of the 2nd order transition probs
396401 n_nodes = self .indptr .size - 1 # number of nodes
@@ -423,7 +428,7 @@ def compute_all_transition_probs():
423428 q ,
424429 idx ,
425430 nbr ,
426- avg_wts ,
431+ noise_thresholds ,
427432 )
428433
429434 start = offset + dim * nbr_idx
@@ -444,9 +449,9 @@ class SparseOTF(Base, SparseRWGraph):
444449
445450 """
446451
447- def __init__ (self , p , q , workers , verbose = False , extend = False ):
452+ def __init__ (self , * args , ** kwargs ):
448453 """Initialize PreComp mode node2vec."""
449- Base .__init__ (self , p , q , workers , verbose , extend )
454+ Base .__init__ (self , * args , ** kwargs )
450455
451456 def get_move_forward (self ):
452457 """Wrap ``move_forward``.
@@ -467,7 +472,7 @@ def get_move_forward(self):
467472 p = self .p
468473 q = self .q
469474
470- get_normalized_probs , avg_wts = self .setup_get_normalized_probs ()
475+ get_normalized_probs , noise_thresholds = self .setup_get_normalized_probs ()
471476
472477 @njit (nogil = True )
473478 def move_forward (cur_idx , prev_idx = None ):
@@ -480,7 +485,7 @@ def move_forward(cur_idx, prev_idx=None):
480485 q ,
481486 cur_idx ,
482487 prev_idx ,
483- avg_wts ,
488+ noise_thresholds ,
484489 )
485490 cdf = np .cumsum (normalized_probs )
486491 choice = np .searchsorted (cdf , np .random .random ())
@@ -499,9 +504,9 @@ class DenseOTF(Base, DenseRWGraph):
499504
500505 """
501506
502- def __init__ (self , p , q , workers , verbose = False , extend = False ):
507+ def __init__ (self , * args , ** kwargs ):
503508 """Initialize DenseOTF mode node2vec."""
504- Base .__init__ (self , p , q , workers , verbose , extend )
509+ Base .__init__ (self , * args , ** kwargs )
505510
506511 def get_move_forward (self ):
507512 """Wrap ``move_forward``.
@@ -521,7 +526,7 @@ def get_move_forward(self):
521526 p = self .p
522527 q = self .q
523528
524- get_normalized_probs , avg_wts = self .setup_get_normalized_probs ()
529+ get_normalized_probs , noise_thresholds = self .setup_get_normalized_probs ()
525530
526531 @njit (nogil = True )
527532 def move_forward (cur_idx , prev_idx = None ):
@@ -533,7 +538,7 @@ def move_forward(cur_idx, prev_idx=None):
533538 q ,
534539 cur_idx ,
535540 prev_idx ,
536- avg_wts ,
541+ noise_thresholds ,
537542 )
538543 cdf = np .cumsum (normalized_probs )
539544 choice = np .searchsorted (cdf , np .random .random ())
0 commit comments