laforge has submitted this change. ( https://gerrit.osmocom.org/c/python/pyosmocom/+/39193?usp=email )
Change subject: construct: allow stripping of leading zeros with StripHeaderAdapter
......................................................................
construct: allow stripping of leading zeros with StripHeaderAdapter
This patch adds a new adapter class, StripHeaderAdapter, which can be used to strip leading zero bytes in the same way that StripTrailerAdapter strips trailing zero bytes.
Related: OS#6679
Change-Id: I1a9fff17abbbef0c5f6d45f58198debfa12e78b6
---
M src/osmocom/construct.py
M tests/test_construct.py
2 files changed, 89 insertions(+), 0 deletions(-)
Approvals:
  Jenkins Builder: Verified
  laforge: Looks good to me, approved
diff --git a/src/osmocom/construct.py b/src/osmocom/construct.py index d4284d9..7a2851a 100644 --- a/src/osmocom/construct.py +++ b/src/osmocom/construct.py @@ -451,6 +451,48 @@ else: return obj_step_aligned;
class StripHeaderAdapter(Adapter):
    """
    Strip/restore leading padding bytes of a fixed-width big-endian integer.

    Encoder: converts the int to total_length big-endian bytes, then removes leading bytes
    equal to default_value, keeping at least min_len bytes.
    Decoder: prepends default_value bytes until the input is total_length bytes long, then
    converts the result to an int (big-endian).

    In case the encoding restricts the length of the result to specific values, the API user may set those restrictions
    using the steps parameter. (e.g. encoded result must be either 1 or 3 byte long, steps would be set to [1,3])
    When steps is given, the encoder returns the shortest stripped form whose length is listed in steps;
    if no stripped length matches, the full total_length encoding is returned.

    This is used in constellations like "FlagsEnum(StripHeaderAdapter(GreedyBytes, 3), ..."
    where you have a bit-mask that may have 1, 2 or 3 bytes, depending on whether or not any
    of the MSBs are actually set.

    Args:
        subcon: wrapped construct (e.g. GreedyBytes)
        total_length: full width of the integer in bytes
        default_value: single padding byte that is stripped on encode and re-inserted on decode
        min_len: the encoder never strips the result below this many bytes
        steps: optional list of permitted encoded lengths; None or [] means no restriction
    """
    def __init__(self, subcon, total_length:int, default_value=b'\x00', min_len=1,
                 steps:typing.Optional[typing.List[int]]=None):
        super().__init__(subcon)
        assert len(default_value) == 1
        self.total_length = total_length
        self.default_value = default_value
        self.min_len = min_len
        # Avoid a shared mutable default argument; an empty list means "no step restriction".
        self.steps = steps if steps is not None else []

    def _decode(self, obj, context, path):
        assert isinstance(obj, bytes)
        # Re-insert the suppressed/missing leading bytes.
        # NOTE(review): input longer than total_length is not truncated or rejected.
        if len(obj) < self.total_length:
            obj = self.default_value * (self.total_length - len(obj)) + obj
        return int.from_bytes(obj, 'big')

    def _encode(self, obj, context, path):
        assert isinstance(obj, int)
        obj = obj.to_bytes(self.total_length, 'big')

        # Remove leading bytes while they equal default_value, but never go below min_len.
        # While stripping, remember the shortest intermediate result whose length is an
        # allowed step; it falls back to the full-length encoding if none matches.
        obj_step_aligned = obj
        while len(obj) > self.min_len and obj[0] == self.default_value[0]:
            obj = obj[1:]
            if len(obj) in self.steps:
                obj_step_aligned = obj

        if not self.steps:
            return obj
        return obj_step_aligned
def filter_dict(d, exclude_prefix='_'): """filter the input dict to ensure no keys starting with 'exclude_prefix' remain.""" diff --git a/tests/test_construct.py b/tests/test_construct.py index e4234db..94ae71f 100755 --- a/tests/test_construct.py +++ b/tests/test_construct.py @@ -101,6 +101,9 @@ final_application=0x0200, global_service=0x0100, receipt_generation=0x80, ciphered_load_file_data_block=0x40, contactless_activation=0x20, contactless_self_activation=0x10) + IntegerSteps = StripTrailerAdapter(GreedyBytes, 4, steps = [2,4]) + Integer = StripTrailerAdapter(GreedyBytes, 4) + examples = ['00', '80', '8040', '400010'] def test_encdec(self): for e in self.examples: @@ -108,6 +111,25 @@ reenc = self.Privileges.build(dec) self.assertEqual(e, b2h(reenc))
+ def test_encdec_integer(self): + enc = self.IntegerSteps.build(0x10000000) + self.assertEqual(b2h(enc), '1000') + enc = self.IntegerSteps.build(0x10200000) + self.assertEqual(b2h(enc), '1020') + enc = self.IntegerSteps.build(0x10203000) + self.assertEqual(b2h(enc), '10203000') + enc = self.IntegerSteps.build(0x10203040) + self.assertEqual(b2h(enc), '10203040') + + enc = self.Integer.build(0x10000000) + self.assertEqual(b2h(enc), '10') + enc = self.Integer.build(0x10200000) + self.assertEqual(b2h(enc), '1020') + enc = self.Integer.build(0x10203000) + self.assertEqual(b2h(enc), '102030') + enc = self.Integer.build(0x10203040) + self.assertEqual(b2h(enc), '10203040') + def test_enc(self): enc = self.Privileges.build({'dap_verification' : True}) self.assertEqual(b2h(enc), '40') @@ -124,6 +146,31 @@ self.assertEqual(b2h(enc), '400110')
+class TestStripHeaderAdapter(unittest.TestCase): + + IntegerSteps = StripHeaderAdapter(GreedyBytes, 4, steps = [2,4]) + Integer = StripHeaderAdapter(GreedyBytes, 4) + + def test_encdec_integer_reverse(self): + enc = self.IntegerSteps.build(0x40) + self.assertEqual(b2h(enc), '0040') + enc = self.IntegerSteps.build(0x3040) + self.assertEqual(b2h(enc), '3040') + enc = self.IntegerSteps.build(0x203040) + self.assertEqual(b2h(enc), '00203040') + enc = self.IntegerSteps.build(0x10203040) + self.assertEqual(b2h(enc), '10203040') + + enc = self.Integer.build(0x40) + self.assertEqual(b2h(enc), '40') + enc = self.Integer.build(0x3040) + self.assertEqual(b2h(enc), '3040') + enc = self.Integer.build(0x203040) + self.assertEqual(b2h(enc), '203040') + enc = self.Integer.build(0x10203040) + self.assertEqual(b2h(enc), '10203040') + + class TestAdapters(unittest.TestCase): def test_dns_adapter(self): ad = DnsAdapter(GreedyBytes)