run_mlmodelscope.py
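"""Run MLModelScope inference either standalone (driven by CLI arguments) or as an
agent that consumes prediction requests from a message queue and writes results to a
database.

Example invocations (illustrative only; the values below are the argparse defaults
or placeholders, not required settings):

    # standalone image classification on the bundled test data
    python run_mlmodelscope.py --task image_classification \
        --model_name torchvision_alexnet --architecture gpu --batch_size 2

    # agent mode, reading database/queue settings from environment variables
    python run_mlmodelscope.py --standalone false --agent pytorch --env_file true
"""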
import argparse
import json
import os
import time
from datetime import datetime, timezone
from uuid import uuid4
from typing import Dict, Any
import numpy as np
from mlmodelscope import MLModelScope
# Constants
TRACE_LEVEL = (
    "NO_TRACE",
    "APPLICATION_TRACE",
    "MODEL_TRACE",           # pipelines within model
    "FRAMEWORK_TRACE",       # layers within framework
    "ML_LIBRARY_TRACE",      # cudnn, ...
    "SYSTEM_LIBRARY_TRACE",  # cupti
    "HARDWARE_TRACE",        # perf, papi, ...
    "FULL_TRACE",            # includes all of the above
)

def parse_args():
    parser = argparse.ArgumentParser(description="mlmodelscope")
    parser.add_argument("--standalone", type=str, nargs='?', default="true", choices=["false", "true"], help="Whether to run standalone (not connected to the frontend)")
    parser.add_argument("--agent", type=str, nargs='?', default="pytorch", choices=["pytorch", "tensorflow", "onnxruntime", "mxnet", "jax", "max"], help="Which framework to use")
    if parser.parse_known_args()[0].standalone == 'true':
        parser.add_argument("--user", type=str, nargs='?', default="default", help="The name of the user")
        parser.add_argument("--task", type=str, nargs='?', default="image_classification", help="The name of the task to predict")
        parser.add_argument("--model_name", type=str, nargs='?', default="torchvision_alexnet", help="The name of the model")
        parser.add_argument("--config_file", type=str, nargs='?', default="false", choices=["false", "true"], help="Whether to use a config file (.json)")
        parser.add_argument("--config_file_path", type=str, nargs='?', default="config.json", help="The path of the config file")
        parser.add_argument("--architecture", type=str, nargs='?', default="gpu", choices=["cpu", "gpu"], help="Which processing unit to use")
        parser.add_argument("--num_warmup", type=int, nargs='?', default=2, help="Total number of warmup steps for predict")
        parser.add_argument("--dataset_name", type=str, nargs='?', default="test_data", help="The name of the dataset for predict")
        parser.add_argument("--batch_size", type=int, nargs='?', default=2, help="Total batch size for predict")
        parser.add_argument("--trace_level", type=str, nargs='?', default="NO_TRACE", choices=TRACE_LEVEL, help="MLModelScope trace level")
        parser.add_argument("--gpu_trace", type=str, nargs='?', default="false", choices=["false", "true"], help="Whether to trace GPU activities")
        parser.add_argument("--cuda_runtime_driver_time_adjustment", type=str, nargs='?', default="false", choices=["false", "true"], help="Whether to adjust the CUDA Runtime/Driver time")
        parser.add_argument("--save_trace_result", type=str, nargs='?', default="false", choices=["false", "true"], help="Whether to save the trace result")
        parser.add_argument("--save_trace_result_path", type=str, nargs='?', default="trace_result.txt", help="The path of the trace result file")
        parser.add_argument("--detailed_result", type=str, nargs='?', default="false", choices=["false", "true"], help="Whether to print the detailed result")
        parser.add_argument("--security_check", type=str, nargs='?', default="false", choices=["false", "true"], help="Whether to perform a security check on the model file")
        parser.add_argument("--save_output", type=str, nargs='?', default="false", choices=["false", "true"], help="Whether to save the output")
        parser.add_argument("--save_output_path", type=str, nargs='?', default="output.json", help="The path of the output file")
    else:
        parser.add_argument("--env_file", type=str, nargs='?', default="false", choices=["false", "true"], help="Whether to use an env file")
        if parser.parse_known_args()[0].env_file == 'false':
            parser.add_argument("--db_dbname", type=str, nargs='?', required=True, help="The name of the database")
            parser.add_argument("--db_host", type=str, nargs='?', default='localhost', help="The host of the database")
            parser.add_argument("--db_port", type=int, nargs='?', default=15432, help="The port of the database")
            parser.add_argument("--db_user", type=str, nargs='?', required=True, help="The user of the database")
            parser.add_argument("--db_password", type=str, nargs='?', required=True, help="The password of the database")
            parser.add_argument("--mq_name", type=str, nargs='?', required=True, help="The name of the message queue")
            parser.add_argument("--mq_host", type=str, nargs='?', default='localhost', help="The host of the message queue")
            parser.add_argument("--mq_port", type=int, nargs='?', default=5672, help="The port of the message queue")
            parser.add_argument("--mq_user", type=str, nargs='?', required=True, help="The user of the message queue")
            parser.add_argument("--mq_password", type=str, nargs='?', required=True, help="The password of the message queue")
    return parser.parse_args()

class DatabaseConnection:
    def __init__(self, host: str, dbname: str, user: str, password: str, port: int):
        self.conn = self._connect_to_db(host, dbname, user, password, port)
        self.cur = self.conn.cursor()

    @staticmethod
    def _connect_to_db(host: str, dbname: str, user: str, password: str, port: int):
        try:
            import psycopg
        except ImportError:
            raise ImportError("Please install psycopg to use database functionality")
        return psycopg.connect(
            f"host={host} dbname={dbname} user={user} password={password} port={port}"
        )

    def update_trial(self, trial_id: str, result: Dict[str, Any]) -> None:
        dt = datetime.now(timezone.utc)
        query = "UPDATE trials SET updated_at = %s, completed_at = %s, result = %s WHERE id = %s;"
        try:
            # single quotes in the raw result cause errors in the frontend, so store it as a JSON string
            self.cur.execute(query, (dt, dt, json.dumps(result), trial_id))
        except Exception as e:
            print(f"Database error: {e}")
            self.conn.rollback()
        else:
            self.conn.commit()

    def close(self) -> None:
        self.conn.close()

class MessageQueueHandler:
    def __init__(self, host: str, port: int, user: str, password: str, queue_name: str):
        self.connection, self.channel = self._connect_to_queue(host, port, user, password, queue_name)
        self.queue_name = queue_name

    @staticmethod
    def _connect_to_queue(host: str, port: int, user: str, password: str, queue_name: str):
        try:
            import pika
        except ImportError:
            raise ImportError("Please install pika to use message queue functionality")
        credentials = pika.PlainCredentials(user, password)
        parameters = pika.ConnectionParameters(host=host, port=port, credentials=credentials)
        connection = pika.BlockingConnection(parameters)
        channel = connection.channel()
        channel.queue_declare(queue=queue_name)
        return connection, channel

    def start_consuming(self, callback) -> None:
        self.channel.basic_consume(
            queue=self.queue_name,
            on_message_callback=callback,
            auto_ack=True
        )
        try:
            print("Waiting for messages. To exit press CTRL+C")
            self.channel.start_consuming()
        except KeyboardInterrupt:
            print("Exiting...")
            self.channel.stop_consuming()

    def close(self) -> None:
        self.connection.close()

def process_message(db_conn: DatabaseConnection, body: bytes, properties, agent: str) -> None:
    received_message = json.loads(body.decode())

    # Extract message parameters
    user = received_message.get('User', 'default')
    task = received_message['DesiredResultModality']
    architecture = 'gpu' if received_message['UseGpu'] else 'cpu'
    trace_level = received_message.get('TraceLevel', 'NO_TRACE')
    # GPU tracing only applies when a trace level is requested and a GPU is used
    gpu_trace = trace_level != "NO_TRACE"
    if architecture != "gpu" and gpu_trace:
        gpu_trace = False
        print("GPU trace disabled for CPU architecture")
    # model_name = received_message['ModelName'][:-4].lower().replace('.', '_')
    model_name = received_message['ModelName'].lower().replace('.', '_')
    num_warmup = received_message.get('NumWarmup', 0)
    dataset_name = received_message['InputFiles']
    batch_size = received_message.get('BatchSize', 1)
    config = received_message.get('Config', None)
    security_check = received_message.get('SecurityCheck', False)

    duration_start = time.time()
    mlms = MLModelScope(architecture, trace_level, gpu_trace)
    mlms.load_agent(task, agent, model_name, security_check, config, user)
    print(f"{agent}-agent loaded with {model_name} model")
    mlms.load_dataset(dataset_name, batch_size, None, security_check)
    print(f"{dataset_name} dataset loaded")

    print("Prediction starts")
    duration_for_inference_start_time = time.time()
    outputs = mlms.predict(num_warmup, True)
    duration_for_inference_end_time = time.time()
    duration_for_inference = duration_for_inference_end_time - duration_for_inference_start_time
    print("Prediction done")
    mlms.Close()
    duration_end_time = time.time()
    duration = duration_end_time - duration_start

    # result = {'duration': duration, 'duration_for_inference': duration_for_inference, 'responses': outputs}
    result = outputs[0]
    result["duration"] = f"{duration:.10f}s"
    result["duration_for_inference"] = f"{duration_for_inference:.10f}s"
    result["responses"][0]["id"] = str(uuid4())
    db_conn.update_trial(properties.correlation_id, result)

def run_standalone(args):
    agent = args.agent
    user = args.user
    task = args.task
    architecture = args.architecture
    trace_level = args.trace_level
    gpu_trace = TRACE_LEVEL.index(trace_level) >= TRACE_LEVEL.index("SYSTEM_LIBRARY_TRACE") and args.gpu_trace == "true"
    if architecture != "gpu" and gpu_trace:
        gpu_trace = False
        print(f"GPU device will not be used because \"{architecture}\" is chosen for architecture.\nTherefore, the gpu_trace option is turned off.")
    model_name = args.model_name
    config = load_config(args.config_file, args.config_file_path)
    num_warmup = args.num_warmup
    dataset_name = args.dataset_name
    batch_size = args.batch_size
    detailed = args.detailed_result == "true"
    security_check = args.security_check == "true"
    save_trace_result_path = args.save_trace_result_path if args.save_trace_result == "true" and trace_level != "NO_TRACE" else None
    save_output_path = args.save_output_path if args.save_output == "true" else None

    mlms = MLModelScope(architecture, trace_level, gpu_trace, save_trace_result_path, args.cuda_runtime_driver_time_adjustment == "true")
    mlms.load_agent(task, agent, model_name, security_check, config, user)
    print(f"{agent}-agent is loaded with {model_name} model\n")
    mlms.load_dataset(dataset_name, batch_size, None, security_check)
    print(f"{dataset_name} dataset is loaded\n")

    print("prediction starts")
    outputs = mlms.predict(num_warmup)
    print("prediction is done\n")

    print_outputs(outputs, task, detailed, save_output_path)
    mlms.Close()

def load_config(config_file_flag, config_file_path):
    if config_file_flag == "true":
        try:
            with open(config_file_path, 'r') as f:
                config = json.load(f)
            print(f"config file {config_file_path} is loaded")
            return config
        except (json.JSONDecodeError, FileNotFoundError) as e:
            print(f"config file {config_file_path} is not loaded: {e}")
    return None

def print_outputs(outputs, task, detailed, save_output_path):
    print("outputs are as follows:")
    if task in ["image_classification", "sentiment_analysis"]:
        if detailed:
            print(outputs)
        else:
            print(np.argmax(outputs, axis=1))
    elif task == "image_object_detection":
        print("image_object_detection")
        print(len(outputs))
        for output in outputs:
            print(f"{len(output[0][0])} {len(output[1][0])} {len(output[2][0])}")
    elif task in ["image_semantic_segmentation", "depth_estimation"]:
        for index, output in enumerate(outputs):
            print(f"outputs[{index}] width: {len(output)}, height: {len(output[0])}")
    elif task in ["image_enhancement", "image_synthesis"]:
        for index, output in enumerate(outputs):
            print(f"outputs[{index}] width: {len(output)}, height: {len(output[0])}, channel: {len(output[0][0])}")
    elif task == "image_instance_segmentation":
        for index, output in enumerate(outputs):
            print(f"outputs[{index}] width: {len(output[0][0])}, height: {len(output[0][0][0])}")
            for i, label in enumerate(output[1]):
                print(f"labels[{i}]: {label}", end=' ')
            print(f"\noutputs[{index}] masks unique number list: {np.unique(output[0]).tolist()}\n")
    elif task == "image_instance_segmentation_raw":
        print(len(outputs))
        print(len(outputs[0]))
    elif task in ["speech_synthesis", "audio_generation"]:
        for index, output in enumerate(outputs):
            print(f"outputs[{index}] length: {len(output)}")
    else:
        for index, output in enumerate(outputs):
            print(f"outputs[{index}]: {output}")

    if task in ["text_to_text"] and save_output_path:
        with open(save_output_path, 'w') as f:
            json.dump(outputs, f, indent=4)
        print(f"result is saved in {save_output_path}")

def run_non_standalone(args):
    db_config, mq_config = load_configs(args)
    db_conn = DatabaseConnection(
        db_config['host'], db_config['name'],
        db_config['user'], db_config['password'],
        int(db_config['port'])
    )
    mq_handler = MessageQueueHandler(
        mq_config['host'], int(mq_config['port']),
        mq_config['user'], mq_config['password'],
        mq_config['name']
    )

    def callback(ch, method, properties, body):
        process_message(db_conn, body, properties, args.agent)

    try:
        mq_handler.start_consuming(callback)
    finally:
        db_conn.close()
        mq_handler.close()

def load_configs(args):
    if args.env_file == 'false':
        db_config = {
            'name': args.db_dbname,
            'host': args.db_host,
            'port': args.db_port,
            'user': args.db_user,
            'password': args.db_password
        }
        mq_config = {
            'name': args.mq_name,
            'host': args.mq_host,
            'port': args.mq_port,
            'user': args.mq_user,
            'password': args.mq_password
        }
    else:
        db_config = {
            'name': os.environ['DB_DBNAME'],
            'host': os.environ['DB_HOST'],
            'port': os.environ['DB_PORT'],
            'user': os.environ['DB_USER'],
            'password': os.environ['DB_PASSWORD']
        }
        mq_config = {
            # 'name': f'agent-{args.agent}-amd64',
            'name': f'agent-{os.environ["MQ_USER"]}-{os.environ["AGENT_VERSION"]}-{os.environ["ARCHITECTURE"]}',
            'host': os.environ['MQ_HOST'],
            'port': os.environ['MQ_PORT'],
            'user': os.environ['MQ_USER'],
            'password': os.environ['MQ_PASSWORD']
        }
    return db_config, mq_config

def main():
    args = parse_args()
    if args.standalone == 'true':
        run_standalone(args)
    else:
        run_non_standalone(args)


if __name__ == "__main__":
    main()