From eaa9fb8a7da7ff4c32e4cf5a96f6a0ddf680c118 Mon Sep 17 00:00:00 2001 From: Raoul Wols Date: Thu, 4 Jun 2026 13:48:06 +0200 Subject: [PATCH] Fix signed vs unsigned bug when (de)serializing 32-bit integers The rule is: * 1-byte integers are unsigned. * 2-byte integers (shorts) are unsigned. * 4-byte integers are **signed**. --- tests/test_creation.py | 111 ++++++++++++++++++++++++++++++++++++++--- trsfile/engine/trs.py | 32 +++++++----- 2 files changed, 122 insertions(+), 21 deletions(-) diff --git a/tests/test_creation.py b/tests/test_creation.py index 3e188e3..eeb9625 100644 --- a/tests/test_creation.py +++ b/tests/test_creation.py @@ -1,16 +1,30 @@ +import math import os -import trsfile -import time +import shutil import tempfile +import time import unittest -import math -import shutil -from trsfile import Trace, SampleCoding, Header, TracePadding -from trsfile.parametermap import TraceParameterMap, TraceParameterDefinitionMap, TraceSetParameterMap, RawTraceData +import numpy + +import trsfile +from trsfile import Header, SampleCoding, Trace, TracePadding +from trsfile.parametermap import ( + RawTraceData, + TraceParameterDefinitionMap, + TraceParameterMap, + TraceSetParameterMap, +) from trsfile.standardparameters import StandardTraceSetParameters -from trsfile.traceparameter import ByteArrayParameter, TraceParameterDefinition, ParameterType, StringParameter, \ - IntegerArrayParameter, BooleanArrayParameter, FloatArrayParameter +from trsfile.traceparameter import ( + BooleanArrayParameter, + ByteArrayParameter, + FloatArrayParameter, + IntegerArrayParameter, + ParameterType, + StringParameter, + TraceParameterDefinition, +) def get_sample(x): @@ -526,6 +540,87 @@ def test_padding(self): # Test that this is indeed not zero self.assertNotEqual(trs_trace[-i - 1], 0) + def __metadata_test( + self, headers: dict[Header, object], expect_overflow: bool = False + ) -> None: + key = "MY_IMPORTANT_ATTRIBUTE" + headers[Header.TRACE_PARAMETER_DEFINITIONS] = TraceParameterDefinitionMap( + {key: TraceParameterDefinition(ParameterType.BYTE, length=4, offset=0)} + ) + sample_count = 1 + trace_count = 1 + if expect_overflow: + ctx: object = self.assertRaises(OverflowError).__enter__() + else: + ctx = None + try: + with trsfile.open( + self.tmp_path, "w", headers=headers, padding_mode=TracePadding.AUTO + ) as fp: + fp.extend( + [ + Trace( + SampleCoding.FLOAT, + [0] * sample_count, + TraceParameterMap( + { + key: ByteArrayParameter( + numpy.array([1, 2, 3, 4], dtype=numpy.uint8) + ) + } + ), + ) + for i in range(0, trace_count) + ] + ) + except BaseException as ex: + if ctx: + ctx.__exit__(type(ex), ex, None) + else: + raise + else: + if ctx: + ctx.__exit__(None, None, None) + + def test_negative_signed_integers_in_metadata(self): + """Ensure that negative values are two's-complement (and 32-bits).""" + self.__metadata_test( + { + Header.DESCRIPTION: "hello there, general kenobi", + Header.ACQUISITION_DEVICE_ID: "Super Ultra Mega Scope", + Header.ACQUISITION_INPUT_IMPEDANCE: -5.0, + Header.OFFSET_X: -10000, + Header.EXTERNAL_CLOCK_MULTIPLIER: -1, + Header.EXTERNAL_CLOCK_PHASE_SHIFT: -35000, + } + ) + + def test_datalength_is_uint16(self): + """Ensure shorts are unsigned""" + self.__metadata_test( + { + # (1<<16) - 1 fits in a uint16 + Header.LENGTH_DATA: (1 << 16) - 1, + } + ) + + def test_throws_overflow_if_datalength_does_not_fit(self): + self.__metadata_test( + { + # (1<<16) does NOT fit in a uint16 + Header.LENGTH_DATA: (1 << 16), + }, + expect_overflow=True, + ) + + def test_titlespace_is_uint8(self): + """Ensure bytes are unsigned""" + self.__metadata_test( + { + # (1<<8)-1 fits in a uint18 + Header.TITLE_SPACE: (1 << 8) - 1, + } + ) if __name__ == '__main__': unittest.main() diff --git a/trsfile/engine/trs.py b/trsfile/engine/trs.py index 55f6ff7..f4aa14f 100644 --- a/trsfile/engine/trs.py +++ b/trsfile/engine/trs.py @@ -1,18 +1,22 @@ -import os -import sys +import copy import mmap +import os import struct -from typing import List, Union, Dict, Any, Optional +import sys +from io import BytesIO +from typing import Any, Dict, List, Optional, Union import numpy -import copy -from io import BytesIO -from trsfile.trace import Trace -from trsfile.traceparameter import ByteArrayParameter from trsfile.common import Header, SampleCoding, TracePadding from trsfile.engine.engine import Engine -from trsfile.parametermap import TraceSetParameterMap, TraceParameterDefinitionMap, TraceParameterMap +from trsfile.parametermap import ( + TraceParameterDefinitionMap, + TraceParameterMap, + TraceSetParameterMap, +) +from trsfile.trace import Trace +from trsfile.traceparameter import ByteArrayParameter ASCII_LESS_THAN = 0x3C @@ -512,7 +516,8 @@ def __write_headers(self, headers: Optional[Dict[Header, Any]] = None): # Obtain the tag value if header.type is int: - tag_value = b'\xff' * header.length if value is None else value.to_bytes(header.length, 'little') + # Note the delicacy with the signedness here. + tag_value = b'\xff' * header.length if value is None else value.to_bytes(header.length, byteorder='little', signed=header.length >= 4) elif header.type is float: tag_value = struct.pack('= 0x80: tag_length_length = (tag_length.bit_length() // 8) + (1 if tag_length.bit_length() % 8 > 0 else 0) - tag += bytes([0x80 | tag_length_length]) + tag_length.to_bytes(tag_length_length, 'little') + tag += bytes([0x80 | tag_length_length]) + tag_length.to_bytes(tag_length_length, byteorder='little', signed=tag_length_length >= 4) else: tag += [tag_length] tag += tag_value @@ -599,7 +604,8 @@ def __read_headers(self) -> None: tag_length = self.handle.read(1)[0] if (tag_length & 0x80) != 0: - tag_length = int.from_bytes(self.handle.read(tag_length & 0x7F), 'little') + tag_length &= 0x7F + tag_length = int.from_bytes(self.handle.read(tag_length), byteorder='little', signed=tag_length >= 4) if tag_length == 0 and tag != Header.TRACE_BLOCK.value: continue @@ -612,7 +618,7 @@ def __read_headers(self) -> None: if Header.has_value(tag): header = Header(tag) if header.type is int: - tag_value = int.from_bytes(tag_value, 'little') + tag_value = int.from_bytes(tag_value, byteorder='little', signed=tag_length >= 4) elif header.type is float: tag_value, = struct.unpack('