laforge submitted this change.
construct: allow stripping of leading zeros with StripHeaderAdapter
This patch adds a new class StripHeaderAdapter. This adapter class can
be used to strip leading zeros in the same way that StripTrailerAdapter
strips trailing zeros.
Related: OS#6679
Change-Id: I1a9fff17abbbef0c5f6d45f58198debfa12e78b6
---
M src/osmocom/construct.py
M tests/test_construct.py
2 files changed, 89 insertions(+), 0 deletions(-)
diff --git a/src/osmocom/construct.py b/src/osmocom/construct.py
index d4284d9..7a2851a 100644
--- a/src/osmocom/construct.py
+++ b/src/osmocom/construct.py
@@ -451,6 +451,48 @@
else:
return obj_step_aligned;
class StripHeaderAdapter(Adapter):
    """
    Adapter that suppresses leading padding bytes of a fixed-width big-endian integer.

    Encoder: converts an int into total_length big-endian bytes and removes all leading bytes
    matching default_value, but never shortens the result below min_len bytes.
    Decoder: pads the input bytes at the front with default_value up to total_length and
    converts the result into an int (big-endian).

    In case the encoding restricts the length of the result to specific values, the API user may set those restrictions
    using the steps parameter (e.g. if the encoded result must be either 1 or 3 bytes long, steps would be set to [1,3]).
    With steps set, the encoder returns the shortest permitted length (below total_length) that still holds the
    stripped value; if no such length exists, the full total_length encoding is returned.

    This is used in constellations like "FlagsEnum(StripHeaderAdapter(GreedyBytes, 3), ..."
    where you have a bit-mask that may have 1, 2 or 3 bytes, depending on whether or not any
    of the MSBs are actually set.

    Args:
        subcon: underlying construct (e.g. GreedyBytes).
        total_length: length in bytes of the full, unstripped encoding.
        default_value: the single padding byte stripped on encode and inserted on decode.
        min_len: the encoder never strips the result below this many bytes.
        steps: optional list of permitted encoded lengths; None or empty means "any length".
    """
    def __init__(self, subcon, total_length: int, default_value=b'\x00', min_len=1,
                 steps: typing.Optional[typing.List[int]] = None):
        super().__init__(subcon)
        assert len(default_value) == 1
        self.total_length = total_length
        self.default_value = default_value
        self.min_len = min_len
        # None instead of a mutable [] default, so instances never share one list object.
        self.steps = [] if steps is None else steps

    def _decode(self, obj, context, path):
        assert isinstance(obj, bytes)
        # re-insert the suppressed/missing leading padding bytes
        if len(obj) < self.total_length:
            obj = self.default_value * (self.total_length - len(obj)) + obj
        return int.from_bytes(obj, 'big')

    def _encode(self, obj, context, path):
        assert isinstance(obj, int)
        obj = obj.to_bytes(self.total_length, 'big')
        pad = self.default_value[0]

        # Remove leading bytes equal to default_value, keeping at least min_len bytes.
        # While stripping, remember the shortest intermediate result whose length is a
        # permitted step; it stays the full-length encoding if no step length is reached.
        obj_step_aligned = obj
        while len(obj) > self.min_len and obj[0] == pad:
            obj = obj[1:]
            if len(obj) in self.steps:
                obj_step_aligned = obj

        # Any empty container (list, tuple, set) means the length is unrestricted.
        if not self.steps:
            return obj
        return obj_step_aligned
def filter_dict(d, exclude_prefix='_'):
"""filter the input dict to ensure no keys starting with 'exclude_prefix' remain."""
diff --git a/tests/test_construct.py b/tests/test_construct.py
index e4234db..94ae71f 100755
--- a/tests/test_construct.py
+++ b/tests/test_construct.py
@@ -101,6 +101,9 @@
final_application=0x0200, global_service=0x0100,
receipt_generation=0x80, ciphered_load_file_data_block=0x40,
contactless_activation=0x20, contactless_self_activation=0x10)
+ IntegerSteps = StripTrailerAdapter(GreedyBytes, 4, steps = [2,4])
+ Integer = StripTrailerAdapter(GreedyBytes, 4)
+
examples = ['00', '80', '8040', '400010']
def test_encdec(self):
for e in self.examples:
@@ -108,6 +111,25 @@
reenc = self.Privileges.build(dec)
self.assertEqual(e, b2h(reenc))
+ def test_encdec_integer(self):
+ enc = self.IntegerSteps.build(0x10000000)
+ self.assertEqual(b2h(enc), '1000')
+ enc = self.IntegerSteps.build(0x10200000)
+ self.assertEqual(b2h(enc), '1020')
+ enc = self.IntegerSteps.build(0x10203000)
+ self.assertEqual(b2h(enc), '10203000')
+ enc = self.IntegerSteps.build(0x10203040)
+ self.assertEqual(b2h(enc), '10203040')
+
+ enc = self.Integer.build(0x10000000)
+ self.assertEqual(b2h(enc), '10')
+ enc = self.Integer.build(0x10200000)
+ self.assertEqual(b2h(enc), '1020')
+ enc = self.Integer.build(0x10203000)
+ self.assertEqual(b2h(enc), '102030')
+ enc = self.Integer.build(0x10203040)
+ self.assertEqual(b2h(enc), '10203040')
+
def test_enc(self):
enc = self.Privileges.build({'dap_verification' : True})
self.assertEqual(b2h(enc), '40')
@@ -124,6 +146,31 @@
self.assertEqual(b2h(enc), '400110')
class TestStripHeaderAdapter(unittest.TestCase):
    """Tests for StripHeaderAdapter: encoding strips leading padding, decoding restores it."""

    IntegerSteps = StripHeaderAdapter(GreedyBytes, 4, steps=[2, 4])
    Integer = StripHeaderAdapter(GreedyBytes, 4)

    def test_encdec_integer_reverse(self):
        # with steps=[2, 4] the result is padded up to the next permitted length
        enc = self.IntegerSteps.build(0x40)
        self.assertEqual(b2h(enc), '0040')
        enc = self.IntegerSteps.build(0x3040)
        self.assertEqual(b2h(enc), '3040')
        enc = self.IntegerSteps.build(0x203040)
        self.assertEqual(b2h(enc), '00203040')
        enc = self.IntegerSteps.build(0x10203040)
        self.assertEqual(b2h(enc), '10203040')

        # without steps every leading zero byte is removed
        enc = self.Integer.build(0x40)
        self.assertEqual(b2h(enc), '40')
        enc = self.Integer.build(0x3040)
        self.assertEqual(b2h(enc), '3040')
        enc = self.Integer.build(0x203040)
        self.assertEqual(b2h(enc), '203040')
        enc = self.Integer.build(0x10203040)
        self.assertEqual(b2h(enc), '10203040')

    def test_enc_zero(self):
        # min_len (default 1) keeps at least one byte even when every byte is padding
        self.assertEqual(b2h(self.Integer.build(0)), '00')
        # with steps, the shortest permitted length that is reached (2) is used
        self.assertEqual(b2h(self.IntegerSteps.build(0)), '0000')

    def test_dec_integer(self):
        # suppressed leading bytes are re-inserted as padding before conversion
        self.assertEqual(self.Integer.parse(bytes.fromhex('40')), 0x40)
        self.assertEqual(self.Integer.parse(bytes.fromhex('3040')), 0x3040)
        self.assertEqual(self.Integer.parse(bytes.fromhex('203040')), 0x203040)
        self.assertEqual(self.Integer.parse(bytes.fromhex('10203040')), 0x10203040)
        self.assertEqual(self.IntegerSteps.parse(bytes.fromhex('0040')), 0x40)
        self.assertEqual(self.IntegerSteps.parse(bytes.fromhex('00203040')), 0x203040)

    def test_roundtrip(self):
        for adapter in (self.Integer, self.IntegerSteps):
            for value in (0, 0x40, 0x3040, 0x203040, 0x10203040):
                with self.subTest(adapter=adapter, value=hex(value)):
                    self.assertEqual(adapter.parse(adapter.build(value)), value)

    def test_default_value(self):
        # a non-zero padding byte is stripped on encode and restored on decode
        ff_padded = StripHeaderAdapter(GreedyBytes, 2, default_value=b'\xff')
        self.assertEqual(b2h(ff_padded.build(0xff12)), '12')
        self.assertEqual(ff_padded.parse(bytes.fromhex('12')), 0xff12)
+
class TestAdapters(unittest.TestCase):
def test_dns_adapter(self):
ad = DnsAdapter(GreedyBytes)
To view, visit change 39193. To unsubscribe, or for help writing mail filters, visit settings.