Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 103 additions & 8 deletions tests/test_creation.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,30 @@
import math
import os
import trsfile
import time
import shutil
import tempfile
import time
import unittest
import math
import shutil

from trsfile import Trace, SampleCoding, Header, TracePadding
from trsfile.parametermap import TraceParameterMap, TraceParameterDefinitionMap, TraceSetParameterMap, RawTraceData
import numpy

import trsfile
from trsfile import Header, SampleCoding, Trace, TracePadding
from trsfile.parametermap import (
RawTraceData,
TraceParameterDefinitionMap,
TraceParameterMap,
TraceSetParameterMap,
)
from trsfile.standardparameters import StandardTraceSetParameters
from trsfile.traceparameter import ByteArrayParameter, TraceParameterDefinition, ParameterType, StringParameter, \
IntegerArrayParameter, BooleanArrayParameter, FloatArrayParameter
from trsfile.traceparameter import (
BooleanArrayParameter,
ByteArrayParameter,
FloatArrayParameter,
IntegerArrayParameter,
ParameterType,
StringParameter,
TraceParameterDefinition,
)


def get_sample(x):
Expand Down Expand Up @@ -526,6 +540,87 @@ def test_padding(self):
# Test that this is indeed not zero
self.assertNotEqual(trs_trace[-i - 1], 0)

def __metadata_test(
self, headers: dict[Header, object], expect_overflow: bool = False
) -> None:
key = "MY_IMPORTANT_ATTRIBUTE"
headers[Header.TRACE_PARAMETER_DEFINITIONS] = TraceParameterDefinitionMap(
{key: TraceParameterDefinition(ParameterType.BYTE, length=4, offset=0)}
)
sample_count = 1
trace_count = 1
if expect_overflow:
ctx: object = self.assertRaises(OverflowError).__enter__()
else:
ctx = None
try:
with trsfile.open(
self.tmp_path, "w", headers=headers, padding_mode=TracePadding.AUTO
) as fp:
fp.extend(
[
Trace(
SampleCoding.FLOAT,
[0] * sample_count,
TraceParameterMap(
{
key: ByteArrayParameter(
numpy.array([1, 2, 3, 4], dtype=numpy.uint8)
)
}
),
)
for i in range(0, trace_count)
]
)
except BaseException as ex:
if ctx:
ctx.__exit__(type(ex), ex, None)
else:
raise
else:
if ctx:
ctx.__exit__(None, None, None)

def test_negative_signed_integers_in_metadata(self):
"""Ensure that negative values are two's-complement (and 32-bits)."""
self.__metadata_test(
{
Header.DESCRIPTION: "hello there, general kenobi",
Header.ACQUISITION_DEVICE_ID: "Super Ultra Mega Scope",
Header.ACQUISITION_INPUT_IMPEDANCE: -5.0,
Header.OFFSET_X: -10000,
Header.EXTERNAL_CLOCK_MULTIPLIER: -1,
Header.EXTERNAL_CLOCK_PHASE_SHIFT: -35000,
}
)

def test_datalength_is_uint16(self):
"""Ensure shorts are unsigned"""
self.__metadata_test(
{
# (1<<16) - 1 fits in a uint16
Header.LENGTH_DATA: (1 << 16) - 1,
}
)

def test_throws_overflow_if_datalength_does_not_fit(self):
self.__metadata_test(
{
# (1<<16) does NOT fit in a uint16
Header.LENGTH_DATA: (1 << 16),
},
expect_overflow=True,
)

def test_titlespace_is_uint8(self):
"""Ensure bytes are unsigned"""
self.__metadata_test(
{
# (1<<8)-1 fits in a uint18
Header.TITLE_SPACE: (1 << 8) - 1,
}
)

if __name__ == '__main__':
unittest.main()
Expand Down
32 changes: 19 additions & 13 deletions trsfile/engine/trs.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
import os
import sys
import copy
import mmap
import os
import struct
from typing import List, Union, Dict, Any, Optional
import sys
from io import BytesIO
from typing import Any, Dict, List, Optional, Union

import numpy
import copy

from io import BytesIO
from trsfile.trace import Trace
from trsfile.traceparameter import ByteArrayParameter
from trsfile.common import Header, SampleCoding, TracePadding
from trsfile.engine.engine import Engine
from trsfile.parametermap import TraceSetParameterMap, TraceParameterDefinitionMap, TraceParameterMap
from trsfile.parametermap import (
TraceParameterDefinitionMap,
TraceParameterMap,
TraceSetParameterMap,
)
from trsfile.trace import Trace
from trsfile.traceparameter import ByteArrayParameter

ASCII_LESS_THAN = 0x3C

Expand Down Expand Up @@ -512,15 +516,16 @@ def __write_headers(self, headers: Optional[Dict[Header, Any]] = None):

# Obtain the tag value
if header.type is int:
tag_value = b'\xff' * header.length if value is None else value.to_bytes(header.length, 'little')
# Note the delicacy with the signedness here.
tag_value = b'\xff' * header.length if value is None else value.to_bytes(header.length, byteorder='little', signed=header.length >= 4)
elif header.type is float:
tag_value = struct.pack('<f', 0.0 if value is None else value)
elif header.type is bool:
tag_value = struct.pack('<?', 0 if value is None else value)
elif header.type is str:
tag_value = value.encode('utf-8')
elif header.type is SampleCoding:
tag_value = b'\xff' if value is None else value.value.to_bytes(1, 'little')
tag_value = b'\xff' if value is None else value.value.to_bytes(1, byteorder='little', signed=False)
elif header.type is bytes:
tag_value = value
elif header.type is TraceSetParameterMap:
Expand Down Expand Up @@ -551,7 +556,7 @@ def __write_headers(self, headers: Optional[Dict[Header, Any]] = None):
tag = [header.value]
if tag_length >= 0x80:
tag_length_length = (tag_length.bit_length() // 8) + (1 if tag_length.bit_length() % 8 > 0 else 0)
tag += bytes([0x80 | tag_length_length]) + tag_length.to_bytes(tag_length_length, 'little')
tag += bytes([0x80 | tag_length_length]) + tag_length.to_bytes(tag_length_length, byteorder='little', signed=tag_length_length >= 4)
else:
tag += [tag_length]
tag += tag_value
Expand Down Expand Up @@ -599,7 +604,8 @@ def __read_headers(self) -> None:
tag_length = self.handle.read(1)[0]

if (tag_length & 0x80) != 0:
tag_length = int.from_bytes(self.handle.read(tag_length & 0x7F), 'little')
tag_length &= 0x7F
tag_length = int.from_bytes(self.handle.read(tag_length), byteorder='little', signed=tag_length >= 4)
if tag_length == 0 and tag != Header.TRACE_BLOCK.value:
continue

Expand All @@ -612,7 +618,7 @@ def __read_headers(self) -> None:
if Header.has_value(tag):
header = Header(tag)
if header.type is int:
tag_value = int.from_bytes(tag_value, 'little')
tag_value = int.from_bytes(tag_value, byteorder='little', signed=tag_length >= 4)
elif header.type is float:
tag_value, = struct.unpack('<f', tag_value)
elif header.type is bool:
Expand Down
Loading