Skip to content

Commit 1346b94

Browse files
committed
Added convert_pth_to_tfjs.py
1 parent 9ac5b4f commit 1346b94

File tree

1 file changed

+273
-0
lines changed

1 file changed

+273
-0
lines changed
Lines changed: 273 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,273 @@
1+
"""
2+
Convert PyTorch DeepLabV3 model (.pth) to TensorFlow.js format.
3+
4+
This script converts the model through the following pipeline:
5+
PyTorch (.pth) -> ONNX -> TensorFlow SavedModel -> TensorFlow.js
6+
7+
Requirements:
8+
pip install torch torchvision onnx onnx2tf tensorflowjs tensorflow
9+
10+
Usage:
11+
python convert_pth_to_tfjs.py
12+
"""
13+
14+
import os
15+
import sys
16+
import shutil
17+
import subprocess
18+
19+
# === Windows Fix for TensorFlow.js ===
20+
# tensorflowjs tries to import tensorflow_decision_forests which is not available on Windows.
21+
# We mock it before importing tensorflowjs to prevent the crash.
22+
try:
23+
import tensorflow_decision_forests
24+
except ImportError:
25+
class MockTFDF:
26+
pass
27+
sys.modules['tensorflow_decision_forests'] = MockTFDF()
28+
# =====================================
29+
30+
def check_dependencies():
    """Return True if every package required by the pipeline is importable.

    Uses importlib.util.find_spec so heavy frameworks (torch, tensorflow)
    are merely located, not fully imported -- the check stays fast and free
    of import-time side effects.  Missing packages are printed together
    with a pip install hint.

    Returns:
        bool: True if all required packages are available, False otherwise.
    """
    import importlib.util

    required = ['torch', 'torchvision', 'onnx', 'onnx2tf', 'tensorflowjs', 'tensorflow']
    missing = [pkg for pkg in required if importlib.util.find_spec(pkg) is None]

    if missing:
        print(f"Missing packages: {', '.join(missing)}")
        print(f"Install with: pip install {' '.join(missing)}")
        return False
    return True
46+
47+
48+
def export_to_onnx(pth_path, onnx_path, input_size=384):
    """Export the DeepLabV3-MobileNetV3 checkpoint at *pth_path* to ONNX.

    Args:
        pth_path: Path to the PyTorch checkpoint (expected to be a plain
            state_dict -- TODO confirm it is not wrapped in a checkpoint dict).
        onnx_path: Destination path for the exported ONNX file.
        input_size: Height/width of the fixed square input (default 384).

    Returns:
        bool: True on success (torch.onnx.export raises on failure).
    """
    import torch
    from torchvision.models.segmentation import deeplabv3_mobilenet_v3_large

    print(f"Loading PyTorch model from {pth_path}...")

    # Create the model architecture; weights are loaded below.
    model = deeplabv3_mobilenet_v3_large(num_classes=2)

    # Load weights on CPU so no GPU is required for conversion.
    state_dict = torch.load(pth_path, map_location='cpu')

    # strict=False ignores auxiliary-classifier weights (inference only),
    # but report what was skipped so mismatches are no longer silent.
    result = model.load_state_dict(state_dict, strict=False)
    if result.missing_keys:
        print(f"Warning: missing keys ignored: {result.missing_keys}")
    if result.unexpected_keys:
        print(f"Warning: unexpected keys ignored: {result.unexpected_keys}")
    model.eval()

    # Dummy input defines the fixed export shape (NCHW).
    dummy_input = torch.randn(1, 3, input_size, input_size)

    print(f"Exporting to ONNX: {onnx_path}...")

    torch.onnx.export(
        model,
        dummy_input,
        onnx_path,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes=None,  # Fixed shape for better converter compatibility
        opset_version=12,
        do_constant_folding=True
    )

    print("ONNX export complete!")
    return True
83+
84+
85+
def convert_onnx_to_tf(onnx_path, tf_output_dir, input_size=384):
    """Convert an ONNX model to a TensorFlow SavedModel via the onnx2tf CLI.

    Args:
        onnx_path: Path of the ONNX model produced by export_to_onnx().
        tf_output_dir: Directory where the SavedModel is written.
        input_size: Height/width used to pin the input shape (default 384).

    Returns:
        bool: True on success, False if onnx2tf fails or is not on PATH.
    """
    print("Converting ONNX to TensorFlow SavedModel...")

    cmd = [
        'onnx2tf',
        '-i', onnx_path,
        '-o', tf_output_dir,
        '-ois', f'input:1,3,{input_size},{input_size}'  # Fix input shape
    ]

    try:
        result = subprocess.run(cmd, capture_output=True, text=True)
    except FileNotFoundError:
        # check_dependencies() verifies the Python module, but the CLI entry
        # point can still be missing from PATH -- fail cleanly rather than
        # with an unhandled traceback.
        print("Error: 'onnx2tf' executable not found on PATH")
        return False

    if result.returncode != 0:
        print(f"Error: {result.stderr}")
        return False

    print("TensorFlow SavedModel conversion complete!")
    return True
104+
105+
106+
def convert_tf_to_tfjs(tf_saved_model_dir, tfjs_output_dir, input_size=384):
    """Convert a TensorFlow SavedModel to TensorFlow.js graph-model format.

    The tensorflowjs converter requires a 'serving_default' signature; if
    the SavedModel lacks one, the model is wrapped in a tf.Module exposing
    that signature and re-saved first.

    Args:
        tf_saved_model_dir: Directory containing the SavedModel.
        tfjs_output_dir: Destination directory for the TensorFlow.js model.
        input_size: Height/width of the input used when a signature must be
            added (default 384, matching the rest of the pipeline; the spec
            is NHWC here -- onnx2tf typically transposes from ONNX's NCHW).

    Returns:
        bool: True on success, False otherwise.
    """
    import tensorflow as tf

    print("Converting TensorFlow SavedModel to TensorFlow.js...")

    # A serving signature is required for tensorflowjs conversion.
    print("Loading SavedModel and adding signature...")

    loaded = tf.saved_model.load(tf_saved_model_dir)

    if hasattr(loaded, 'signatures') and 'serving_default' in loaded.signatures:
        print("SavedModel already has serving_default signature")
    else:
        # Need to wrap and re-save with a signature.
        print("Adding serving_default signature...")

        # Find a callable to use as the inference function.
        infer = None
        if hasattr(loaded, '__call__'):
            infer = loaded.__call__
        elif hasattr(loaded, 'serve'):
            infer = loaded.serve

        if infer is None:
            # Last resort: pick any public callable attribute.
            for key in dir(loaded):
                attr = getattr(loaded, key)
                if callable(attr) and not key.startswith('_'):
                    infer = attr
                    break

        if infer is None:
            print("Could not find inference function, trying direct conversion...")
        else:
            # Wrapper module that exposes the required serving signature.
            class WrapperModule(tf.Module):
                def __init__(self, model):
                    super().__init__()
                    self.model = model

                @tf.function(input_signature=[
                    tf.TensorSpec(shape=[1, input_size, input_size, 3],
                                  dtype=tf.float32)
                ])
                def serving_default(self, x):
                    return self.model(x)

            wrapper = WrapperModule(loaded)

            # Re-save with the signature attached, then convert the wrapper.
            wrapped_dir = tf_saved_model_dir + '_wrapped'
            tf.saved_model.save(
                wrapper,
                wrapped_dir,
                signatures={'serving_default': wrapper.serving_default}
            )
            tf_saved_model_dir = wrapped_dir

    # Now convert to TensorFlow.js.
    print("Running tensorflowjs converter...")

    import tensorflowjs as tfjs

    try:
        # Newer tensorflowjs versions accept signature_def explicitly.
        tfjs.converters.convert_tf_saved_model(
            tf_saved_model_dir,
            tfjs_output_dir,
            signature_def='serving_default'
        )
        print(f"TensorFlow.js model saved to: {tfjs_output_dir}")
        return True
    except TypeError:
        try:
            # Older API: no signature argument (defaults to serving_default).
            tfjs.converters.convert_tf_saved_model(
                tf_saved_model_dir,
                tfjs_output_dir
            )
            print(f"TensorFlow.js model saved to: {tfjs_output_dir}")
            return True
        except Exception as e:
            print(f"Second attempt failed: {e}")
            raise e
    except Exception as e:
        print(f"Conversion error: {e}")

        # Fall back to the command-line converter.
        print("Trying command line converter...")
        cmd = [
            sys.executable, '-m', 'tensorflowjs.converters.converter',
            '--input_format=tf_saved_model',
            '--output_format=tfjs_graph_model',
            '--signature_name=serving_default',
            tf_saved_model_dir,
            tfjs_output_dir
        ]

        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode == 0:
            print(f"TensorFlow.js model saved to: {tfjs_output_dir}")
            return True
        else:
            print(f"Command line conversion failed: {result.stderr}")
            return False
213+
214+
215+
def main():
    """Run the full .pth -> ONNX -> SavedModel -> TensorFlow.js pipeline.

    Exits with status 1 on missing dependencies or any failed stage.
    """
    # Configuration
    PTH_PATH = 'model_mbv3_iou_mix_2C049.pth'
    ONNX_PATH = 'temp_model.onnx'
    TF_SAVED_MODEL_DIR = 'tf_saved_model'
    TFJS_OUTPUT_DIR = 'web_app_tfjs/tfjs_model'
    INPUT_SIZE = 384

    print("=" * 50)
    print("PyTorch to TensorFlow.js Converter")
    print("=" * 50)

    # Check dependencies
    if not check_dependencies():
        sys.exit(1)

    # Force cleanup of previous run artifacts to ensure fresh conversion.
    # Because of this cleanup, every stage below always runs; there is no
    # "artifact already exists" shortcut.
    if os.path.exists(ONNX_PATH):
        os.remove(ONNX_PATH)
    if os.path.exists(TF_SAVED_MODEL_DIR):
        shutil.rmtree(TF_SAVED_MODEL_DIR)
    if os.path.exists(TF_SAVED_MODEL_DIR + '_wrapped'):
        shutil.rmtree(TF_SAVED_MODEL_DIR + '_wrapped')

    # Step 1: PyTorch -> ONNX
    if not export_to_onnx(PTH_PATH, ONNX_PATH, INPUT_SIZE):
        print("Failed to export ONNX model")
        sys.exit(1)

    # Step 2: ONNX -> TensorFlow SavedModel
    if not convert_onnx_to_tf(ONNX_PATH, TF_SAVED_MODEL_DIR, INPUT_SIZE):
        print("Failed to convert to TensorFlow")
        sys.exit(1)

    # Step 3: TensorFlow SavedModel -> TensorFlow.js
    os.makedirs(TFJS_OUTPUT_DIR, exist_ok=True)
    if not convert_tf_to_tfjs(TF_SAVED_MODEL_DIR, TFJS_OUTPUT_DIR):
        print("Failed to convert to TensorFlow.js")
        sys.exit(1)

    # Cleanup temporary files
    if os.path.exists(ONNX_PATH):
        os.remove(ONNX_PATH)
        print(f"Cleaned up: {ONNX_PATH}")

    print("=" * 50)
    print("Conversion complete!")
    print(f"TensorFlow.js model: {TFJS_OUTPUT_DIR}/model.json")
    print("=" * 50)
270+
271+
272+
# Script entry point: run the conversion pipeline when executed directly.
if __name__ == '__main__':
    main()

0 commit comments

Comments
 (0)