Skip to content

Converting decoder to TensorRT #357

@Amarnath1906

Description

@Amarnath1906

I am trying to convert HIFI GAN decoder to tensorrt.
Here is the script I am trying to use,
`

import tensorrt as trt
import os
  
def build_engine(onnx_file_path):

  logger = trt.Logger(trt.Logger.VERBOSE)
  builder = trt.Builder(logger)

  network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
  network = builder.create_network(network_flags)
  parser = trt.OnnxParser(network, logger)

  if not os.path.exists(onnx_file_path):
      print(f"ONNX file {onnx_file_path} not found.")
      return None

  print(f"Loading ONNX file: {onnx_file_path}")
  with open(onnx_file_path, "rb") as model:
      if not parser.parse(model.read()):
          for i in range(parser.num_errors):
              print(parser.get_error(i))
          return None

  config = builder.create_builder_config()

  # L40S has large memory → allow bigger workspace
  config.set_memory_pool_limit(
      trt.MemoryPoolType.WORKSPACE,
      4 * 1024 * 1024 * 1024   # 4GB workspace (safe on 48GB L40S)
  )

  # FP16 (always good)
  if builder.platform_has_fast_fp16:
      config.set_flag(trt.BuilderFlag.FP16)
      print("FP16 mode enabled (L40S optimized).")

  
  # ───────────────────────────────────────────────
  # Dynamic Shapes (same as your model)
  # ───────────────────────────────────────────────
  profile = builder.create_optimization_profile()

  profile.set_shape("ASR",
      (1, 512, 28),
      (1, 512, 100),
      (1, 512, 1106)
  )

  profile.set_shape("F0_PRED",
      (1, 56),
      (1, 200),
      (1, 2212)
  )

  profile.set_shape("N_PRED",
      (1, 56),
      (1, 200),
      (1, 2212)
  )

  profile.set_shape("REF",
      (1, 128),
      (1, 128),
      (1, 128)
  )

  config.add_optimization_profile(profile)

  print("Building TensorRT engine on L40S...")

  serialized_engine = builder.build_serialized_network(network, config)

  if serialized_engine:
      with open("sample.engine", "wb") as f:
          f.write(serialized_engine)
      print("Success: Engine saved as sample.engine")
      return serialized_engine
  else:
      print("Error: Build failed.")
      return None

if name == "main":
ONNX_PATH = "decoder_v2.onnx"
build_engine(ONNX_PATH)

`

And I am getting the following error,
[01/16/2026-10:05:01] [TRT] [E] IBuilder::buildSerializedNetwork: Error Code 4: Internal Error (kOPT values for profile 0 violate shape constraints: /Concat: axis 2 dimensions must be equal for concatenation on axis 1. Dimensions are seq_len and (+ (CEIL_DIV (+ seq_len -2) 2) 1). Condition '==' violated: 100 != 50.) Error: Build failed.

This is the Decoder function I have,

`

class Decoder(nn.Module):
      def __init__(self, dim_in=512, F0_channel=512, style_dim=64, dim_out=80, 
                  resblock_kernel_sizes = [3,7,11],
                  upsample_rates = [10,5,3,2],
                  upsample_initial_channel=512,
                  resblock_dilation_sizes=[[1,3,5], [1,3,5], [1,3,5]],
                  upsample_kernel_sizes=[20,10,6,4]):
          super().__init__()
      
      self.decode = nn.ModuleList()
      
      self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim)
      
      self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
      self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
      self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
      self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True))

      self.F0_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
      
      self.N_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
      
      self.asr_res = nn.Sequential(
          weight_norm(nn.Conv1d(512, 64, kernel_size=1)),
      )
      
      
      self.generator = Generator(style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes)
  
  def match(self,t, ref):
      if t.shape[-1] != ref.shape[-1]:
          t = F.interpolate(t, size=ref.shape[-1], mode="nearest")
      return t


      
  def forward(self, asr, F0_curve, N, s):
      
      if self.training:
          downlist = [0, 3, 7]
          F0_down = downlist[random.randint(0, 2)]
          downlist = [0, 3, 7, 15]
          N_down = downlist[random.randint(0, 3)]
          if F0_down:
              F0_curve = nn.functional.conv1d(F0_curve.unsqueeze(1), torch.ones(1, 1, F0_down).to('cuda'), padding=F0_down//2).squeeze(1) / F0_down
          if N_down:
              N = nn.functional.conv1d(N.unsqueeze(1), torch.ones(1, 1, N_down).to('cuda'), padding=N_down//2).squeeze(1)  / N_down

      
      F0 = self.F0_conv(F0_curve.unsqueeze(1))
      N = self.N_conv(N.unsqueeze(1))
      
      x = torch.cat([asr, F0, N], axis=1)
      x = self.encode(x, s)
  
      
      asr_res = self.asr_res(asr)
       
      res = True
      for block in self.decode:
          if res:
              
              x = torch.cat([x, asr_res, F0, N], axis=1)
          x = block(x, s)
          if block.upsample_type != "none":
              res = False
              
      x = self.generator(x, s, F0_curve)
      return x

`

This is how i am converting the decoder to onnx,

`

def export_decoder_dynamic(model,repo_path):
bmodel = model
decoder = bmodel.decoder.eval().cuda()

    model_path = "decoder_v2.onnx"
  
  # os.makedirs(os.path.dirname(model_path), exist_ok=True)

  batch = 2
  seq_len = 40

  asr = torch.randn(batch, 512, seq_len, dtype=torch.float32).cuda()
  F0_pred = torch.randn(batch, seq_len*2, dtype=torch.float32).cuda()
  # F0 =  torch.randn(batch, 1, seq_len, dtype=torch.float32)
  N_pred = torch.randn(batch, seq_len*2, dtype=torch.float32).cuda()
  # N_pred = torch.randn(batch, 1, seq_len, dtype=torch.float32)
  ref = torch.randn(batch, 128, dtype=torch.float32).cuda()             # [B, S]

  print(asr.shape,F0_pred.shape,N_pred.shape,ref.shape)

  try:
      torch.onnx.export(
          decoder,
          (asr,F0_pred,N_pred,ref),
          model_path,
          input_names=["ASR","F0_PRED","N_PRED","REF"],
          output_names=["AUDIO_OUT"],
          dynamic_axes={
              "ASR": {0: "batch", 2: "seq_len"},       # input [B, C, T]
              "F0_PRED": {0: "batch", 1: "seq_len"},       # input [B, C, T]
              "N_PRED": {0: "batch", 1: "seq_len"},       # input [B, C, T]
              "REF": {0: "batch"},       # input [B, C, T]
              "AUDIO_OUT": {0: "batch", 2: "seq_len"},   # output [B, C, T]
          },
          opset_version=17,
          do_constant_folding=True,
          verbose=False

      )

      print("decoder exported:", model_path)
  except Exception as e:
      print(traceback.print_exc())
      print(f"Failed to export decoder: {e}")


  try:
      model = onnx.load(model_path)
      model_with_shapes = onnx.shape_inference.infer_shapes(model)
      onnx.save(model_with_shapes, model_path)

      # Create ONNX Runtime session
      session = ort.InferenceSession(model_path)

      # Check the input names
      input_names = [inp.name for inp in session.get_inputs()]
      print("ONNX input names:", input_names)

      # Check output names
      output_names = [out.name for out in session.get_outputs()]
      print("ONNX output names:", output_names)

      onnx_out = session.run(None,
          {"ASR": asr.cpu().numpy(), 
           "F0_PRED": F0_pred.cpu().numpy(),
           "N_PRED": N_pred.cpu().numpy(),
           "REF": ref.cpu().numpy()
           }
      )[0]
      
      print("ONNX output shape:", onnx_out.shape)
      
  except Exception as e:
      print(traceback.print_exc())
      print(f"Failed to export decoder: {e}")
      return

`

I am using nvcr.io/nvidia/tensorrt:25.10-py3 container to convert this and here are few more specs about the GPU,
`

ok-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 555.42.06 Driver Version: 555.42.06 CUDA Version: 12.5 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA L40S Off | 00000000:01:01.0 Off | 0 |
| N/A 39C P0 81W / 350W | 23923MiB / 46068MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+

`

CUDA Verison: 12.6.

Let me know where I am going wrong.

Thanks in advance.

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions