Commit 45f4288

Merge pull request #979 from ottointhesky/task_label_feature
Add task label
2 parents: c4bdfaf + 5bd8530

7 files changed: +136 -10 lines changed
Lines changed: 51 additions & 0 deletions (new example file)

@@ -0,0 +1,51 @@
"""Basic task label example"""

import ipyparallel as ipp

# start up ipp cluster with 2 engines
cluster = ipp.Cluster(n=2)
cluster.start_cluster_sync()

rc = cluster.connect_client_sync()
rc.wait_for_engines(n=2)


def wait(t):
    import time

    tic = time.time()
    time.sleep(t)
    return time.time() - tic


# use load balanced view
bview = rc.load_balanced_view()
ar_list_b1 = [
    bview.set_flags(label=f"mylabel_map_{i:02}").map_async(wait, [2]) for i in range(10)
]
ar_list_b2 = [
    bview.set_flags(label=f"mylabel_map_{i:02}").apply_async(wait, 2) for i in range(10)
]
bview.wait(ar_list_b1)
bview.wait(ar_list_b2)


# use direct view
dview = rc[:]
ar_list_d1 = [
    dview.set_flags(label=f"mylabel_map_{i + 10:02}").apply_async(wait, 2)
    for i in range(10)
]
ar_list_d2 = [
    dview.set_flags(label=f"mylabel_map_{i + 10:02}").map_async(wait, [2])
    for i in range(10)
]
dview.wait(ar_list_d1)
dview.wait(ar_list_d2)

# query database
data = rc.db_query({'label': {"$nin": ""}}, keys=['msg_id', 'label', 'engine_uuid'])
for d in data:
    print(f"msg_id={d['msg_id']}; label={d['label']}; engine_uuid={d['engine_uuid']}")

cluster.stop_cluster_sync()
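
The query at the end of the example pulls back every labelled record; the same Client.db_query call can also filter on one exact label when only a single batch is of interest. A minimal follow-up sketch, assuming the cluster and the mylabel_map_* labels from the script above are still live (the chosen label and printed keys are illustrative):

# Hypothetical follow-up: fetch only the records submitted under one label.
records = rc.db_query(
    {'label': 'mylabel_map_03'},
    keys=['msg_id', 'completed', 'engine_uuid'],
)
for rec in records:
    print(f"{rec['msg_id']} finished at {rec['completed']} on {rec['engine_uuid']}")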

ipyparallel/client/client.py

Lines changed: 2 additions & 0 deletions
@@ -220,6 +220,7 @@ def __init__(self, *args, **kwargs):
             'stderr': '',
             'outputs': [],
             'data': {},
+            'label': None,
         }
         self.update(md)
         self.update(dict(*args, **kwargs))
@@ -822,6 +823,7 @@ def _extract_metadata(self, msg):
             'status': content['status'],
             'is_broadcast': msg_meta.get('is_broadcast', False),
             'is_coalescing': msg_meta.get('is_coalescing', False),
+            'label': msg_meta.get('label', None),
         }

         if md['engine_uuid'] is not None:
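
Because _extract_metadata now copies the label out of the reply metadata, a finished task's label is also visible on the client without a database query, via AsyncResult.metadata. A minimal sketch, assuming a running cluster and the rc, bview, and wait objects from the example script above (the label string is illustrative):

# Submit one labelled task and read the label back from the reply metadata.
ar = bview.set_flags(label="demo_label").apply_async(wait, 0.1)
ar.wait()
# AsyncResult.metadata is a list with one Metadata dict per message;
# after this change each entry should carry the submitted label.
print(ar.metadata[0]['label'])  # expected: demo_label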

ipyparallel/client/view.py

Lines changed: 63 additions & 8 deletions
@@ -98,14 +98,15 @@ class View(HasTraits):
     block = Bool(False)
     track = Bool(False)
     targets = Any()
+    label = Any()

     history = List()
     outstanding = Set()
     results = Dict()
     client = Instance('ipyparallel.Client', allow_none=True)

     _socket = Any()
-    _flag_names = List(['targets', 'block', 'track'])
+    _flag_names = List(['targets', 'block', 'track', 'label'])
     _in_sync_results = Bool(False)
     _targets = Any()
     _idents = Any()
@@ -155,6 +156,8 @@ def set_flags(self, **kwargs):
             else:
                 setattr(self, name, value)

+        return self  # returning self would allow direct calling of map/apply in one command (no context manager)
+
     @contextmanager
     def temp_flags(self, **kwargs):
         """temporarily set flags, for use in `with` statements.
@@ -530,7 +533,14 @@ def use_pickle(self):
     @sync_results
     @save_ids
     def _really_apply(
-        self, f, args=None, kwargs=None, targets=None, block=None, track=None
+        self,
+        f,
+        args=None,
+        kwargs=None,
+        targets=None,
+        block=None,
+        track=None,
+        label=None,
     ):
         """calls f(*args, **kwargs) on remote engines, returning the result.

@@ -562,6 +572,8 @@ def _really_apply(
         block = self.block if block is None else block
         track = self.track if track is None else track
         targets = self.targets if targets is None else targets
+        label = self.label if label is None else label
+        metadata = dict(label=label)

         _idents, _targets = self.client._build_targets(targets)
         futures = []
@@ -572,7 +584,13 @@

         for ident in _idents:
             future = self.client.send_apply_request(
-                self._socket, pf, pargs, pkwargs, track=track, ident=ident
+                self._socket,
+                pf,
+                pargs,
+                pkwargs,
+                track=track,
+                ident=ident,
+                metadata=metadata,
             )
             futures.append(future)
             if track:
@@ -592,7 +610,15 @@
         return ar

     @sync_results
-    def map(self, f, *sequences, block=None, track=False, return_exceptions=False):
+    def map(
+        self,
+        f,
+        *sequences,
+        block=None,
+        track=False,
+        return_exceptions=False,
+        label=None,
+    ):
         """Parallel version of builtin `map`, using this View's `targets`.

         There will be one task per target, so work will be chunked
@@ -630,10 +656,17 @@ def map(self, f, *sequences, block=None, track=False, return_exceptions=False):

         if block is None:
             block = self.block
+        if label is None:
+            label = self.label

         assert len(sequences) > 0, "must have some sequences to map onto!"
         pf = ParallelFunction(
-            self, f, block=block, track=track, return_exceptions=return_exceptions
+            self,
+            f,
+            block=block,
+            track=track,
+            return_exceptions=return_exceptions,
+            label=label,
         )
         return pf.map(*sequences)

@@ -1036,7 +1069,15 @@ def _broadcast_map(f, *sequence_names):
         return list(map(f, *sequences))

     @_not_coalescing
-    def map(self, f, *sequences, block=None, track=False, return_exceptions=False):
+    def map(
+        self,
+        f,
+        *sequences,
+        block=None,
+        track=False,
+        return_exceptions=False,
+        label=None,
+    ):
         """Parallel version of builtin `map`, using this View's `targets`.

         There will be one task per engine, so work will be chunked
@@ -1176,10 +1217,11 @@ class LoadBalancedView(View):
     after = Any()
     timeout = CFloat()
     retries = Integer(0)
+    label = Any()

     _task_scheme = Any()
     _flag_names = List(
-        ['targets', 'block', 'track', 'follow', 'after', 'timeout', 'retries']
+        ['targets', 'block', 'track', 'follow', 'after', 'timeout', 'retries', 'label']
     )
     _outstanding_maps = Set()

@@ -1275,6 +1317,8 @@ def set_flags(self, **kwargs):

             self.timeout = t

+        return self  # returning self would allow direct calling of map/apply in one command (no context manager)
+
     @sync_results
     @save_ids
     def _really_apply(
@@ -1289,6 +1333,7 @@ def _really_apply(
         timeout=None,
         targets=None,
         retries=None,
+        label=None,
     ):
         """calls f(*args, **kwargs) on a remote engine, returning the result.

@@ -1344,6 +1389,7 @@ def _really_apply(
         follow = self.follow if follow is None else follow
         timeout = self.timeout if timeout is None else timeout
         targets = self.targets if targets is None else targets
+        label = self.label if label is None else label

         if not isinstance(retries, int):
             raise TypeError(f'retries must be int, not {type(retries)!r}')
@@ -1358,7 +1404,12 @@
         after = self._render_dependency(after)
         follow = self._render_dependency(follow)
         metadata = dict(
-            after=after, follow=follow, timeout=timeout, targets=idents, retries=retries
+            after=after,
+            follow=follow,
+            timeout=timeout,
+            targets=idents,
+            retries=retries,
+            label=label,
         )

         future = self.client.send_apply_request(
@@ -1389,6 +1440,7 @@ def map(
         chunksize=1,
         ordered=True,
         return_exceptions=False,
+        label=None,
     ):
         """Parallel version of builtin `map`, load-balanced by this View.

@@ -1433,6 +1485,8 @@ def map(
         # default
         if block is None:
             block = self.block
+        if label is None:
+            label = self.label

         assert len(sequences) > 0, "must have some sequences to map onto!"

@@ -1443,6 +1497,7 @@
             chunksize=chunksize,
             ordered=ordered,
             return_exceptions=return_exceptions,
+            label=label,
         )
         return pf.map(*sequences)
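
Both View.set_flags and LoadBalancedView.set_flags now end with return self, so a label can be attached and the work submitted in one expression instead of going through temp_flags. A minimal sketch, assuming an already running cluster (the client setup and label name are illustrative):

import ipyparallel as ipp

rc = ipp.Client()
lbview = rc.load_balanced_view()

# one-liner made possible by set_flags() returning self
ar = lbview.set_flags(label="nightly_batch").apply_async(pow, 2, 10)

# equivalent with the pre-existing context-manager style, which also
# restores the previous label when the block exits
with lbview.temp_flags(label="nightly_batch"):
    ar2 = lbview.apply_async(pow, 2, 10)

print(ar.get(), ar2.get())

Note that set_flags changes the view's label for all later submissions until it is set again, while temp_flags restores the old value on exit; which style fits depends on whether the label should stick.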

ipyparallel/controller/hub.py

Lines changed: 2 additions & 0 deletions
@@ -75,6 +75,7 @@ def empty_record():
         'error': None,
         'stdout': '',
         'stderr': '',
+        'label': None,
     }


@@ -111,6 +112,7 @@ def init_record(msg):
         'error': None,
         'stdout': '',
         'stderr': '',
+        'label': msg['metadata'].get('label', None),
     }


ipyparallel/controller/mongodb.py

Lines changed: 13 additions & 1 deletion
@@ -1,7 +1,7 @@
 """A TaskRecord backend using mongodb"""

 try:
-    from pymongo import MongoClient
+    from pymongo import MongoClient, version
 except ImportError:
     from pymongo import Connection as MongoClient

@@ -15,6 +15,11 @@

 from .dictdb import BaseDB

+# we need to determine the pymongo version because of API changes. see
+# https://pymongo.readthedocs.io/en/stable/migrate-to-pymongo4.html
+pymongo_version_major = int(version.split('.')[0])
+pymongo_version_minor = int(version.split('.')[1])
+
 # -----------------------------------------------------------------------------
 # MongoDB class
 # -----------------------------------------------------------------------------
@@ -56,6 +61,13 @@ def __init__(self, **kwargs):
         self.database = self.session
         self._db = self._connection[self.database]
         self._records = self._db['task_records']
+        if pymongo_version_major >= 4:
+            # mimic the old API 3.x
+            self._records.insert = self._records.insert_one
+            self._records.update = self._records.update_one
+            self._records.ensure_index = self._records.create_index
+            self._records.remove = self._records.delete_many
+
         self._records.ensure_index('msg_id', unique=True)
         self._records.ensure_index('submitted')  # for sorting history
         # for rec in self._records.find

ipyparallel/controller/sqlitedb.py

Lines changed: 4 additions & 1 deletion
@@ -155,6 +155,7 @@ class SQLiteDB(BaseDB):
             'error',
             'stdout',
             'stderr',
+            'label',
         ]
     )
     # sqlite datatypes for checking that db is current format
@@ -182,6 +183,7 @@ class SQLiteDB(BaseDB):
             'error': 'text',
             'stdout': 'text',
             'stderr': 'text',
+            'label': 'text',
         }
     )

@@ -303,7 +305,8 @@ def _init_db(self):
                execute_result text,
                error text,
                stdout text,
-               stderr text)
+               stderr text,
+               label text)
                """
         )
         self._db.commit()

ipyparallel/engine/kernel.py

Lines changed: 1 addition & 0 deletions
@@ -76,6 +76,7 @@ def init_metadata(self, parent):
             'is_broadcast': parent_metadata.get('is_broadcast', False),
             'is_coalescing': parent_metadata.get('is_coalescing', False),
             'original_msg_id': parent_metadata.get('original_msg_id', ''),
+            'label': parent_metadata.get('label', None),
         }

     def finish_metadata(self, parent, metadata, reply_content):
