Change in osmocom-bb[master]: trx_toolkit: check in simple declarative codec

laforge gerrit-no-reply at lists.osmocom.org
Mon Apr 12 13:08:52 UTC 2021


laforge has submitted this change. ( https://gerrit.osmocom.org/c/osmocom-bb/+/23135 )

Change subject: trx_toolkit: check in simple declarative codec
......................................................................

trx_toolkit: check in simple declarative codec

Change-Id: I7ff46b278c59af3720ee7f3950ea5a8b2f1313e1
Related: OS#4006, SYS#4895
---
A src/target/trx_toolkit/codec.py
A src/target/trx_toolkit/test_codec.py
2 files changed, 1,000 insertions(+), 0 deletions(-)

Approvals:
  pespin: Looks good to me, but someone else must approve
  laforge: Looks good to me, approved; Verified



diff --git a/src/target/trx_toolkit/codec.py b/src/target/trx_toolkit/codec.py
new file mode 100644
index 0000000..7a42c9b
--- /dev/null
+++ b/src/target/trx_toolkit/codec.py
@@ -0,0 +1,408 @@
+# -*- coding: utf-8 -*-
+
+'''
+Very simple (performance oriented) declarative message codec.
+Inspired by Pycrate and Scapy.
+'''
+
+# TRX Toolkit
+#
+# (C) 2021 by sysmocom - s.f.m.c. GmbH <info at sysmocom.de>
+# Author: Vadim Yanitskiy <vyanitskiy at sysmocom.de>
+#
+# All Rights Reserved
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation; either version 2 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License along
+# with this program; if not, write to the Free Software Foundation, Inc.,
+# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+
+import abc
+import copy
+from typing import Optional, Callable, Tuple, Any
+
+class ProtocolError(Exception):
+	''' Error in a protocol definition. '''
+
+class DecodeError(Exception):
+	''' Error during decoding of a field/message. '''
+
+class EncodeError(Exception):
+	''' Error during encoding of a field/message. '''
+
+
+class Codec(abc.ABC):
+	''' Base class providing encoding and decoding API. '''
+
+	@abc.abstractmethod
+	def from_bytes(self, vals: dict, data: bytes) -> int:
+		''' Decode value(s) from the given buffer of bytes. '''
+
+	@abc.abstractmethod
+	def to_bytes(self, vals: dict) -> bytes:
+		''' Encode value(s) into bytes. '''
+
+
+class Field(Codec):
+	''' Base class representing one field in a Message. '''
+
+	# Default length (0 means the whole buffer)
+	DEF_LEN: int = 0
+
+	# Default parameters
+	DEF_PARAMS: dict = { }
+
+	# Presence of a field during decoding and encoding
+	get_pres: Callable[[dict], bool]
+	# Length of a field for self.from_bytes()
+	get_len: Callable[[dict, bytes], int]
+	# Value of a field for self.to_bytes()
+	get_val: Callable[[dict], Any]
+
+	def __init__(self, name: str, **kw) -> None:
+		self.name = name
+
+		self.len = kw.get('len', self.DEF_LEN)
+		if self.len == 0: # flexible field
+			self.get_len = lambda _, data: len(data)
+		else: # fixed length
+			self.get_len = lambda vals, _: self.len
+
+		# Field is unconditionally present by default
+		self.get_pres = lambda vals: True
+		# Field takes its value from the given dict by default
+		self.get_val = lambda vals: vals[self.name]
+
+		# Additional parameters for derived field types
+		self.p = { key : kw.get(key, self.DEF_PARAMS[key])
+				for key in self.DEF_PARAMS }
+
+	def from_bytes(self, vals: dict, data: bytes) -> int:
+		if self.get_pres(vals) is False:
+			return 0
+		length = self.get_len(vals, data)
+		if len(data) < length:
+			raise DecodeError('Short read')
+		self._from_bytes(vals, data[:length])
+		return length
+
+	def to_bytes(self, vals: dict) -> bytes:
+		if self.get_pres(vals) is False:
+			return b''
+		data = self._to_bytes(vals)
+		if self.len > 0 and len(data) != self.len:
+			raise EncodeError('Field length mismatch')
+		return data
+
+	@abc.abstractmethod
+	def _from_bytes(self, vals: dict, data: bytes) -> None:
+		''' Decode value(s) from the given buffer of bytes. '''
+		raise NotImplementedError
+
+	@abc.abstractmethod
+	def _to_bytes(self, vals: dict) -> bytes:
+		''' Encode value(s) into bytes. '''
+		raise NotImplementedError
+
+
+class Buf(Field):
+	''' A sequence of octets. '''
+
+	def _from_bytes(self, vals: dict, data: bytes) -> None:
+		vals[self.name] = data
+
+	def _to_bytes(self, vals: dict) -> bytes:
+		# TODO: handle len(self.get_val()) < self.get_len()
+		return self.get_val(vals)
+
+
+class Spare(Field):
+	''' Spare filling for RFU fields or padding. '''
+
+	# Default parameters
+	DEF_PARAMS = {
+		'filler'	: b'\x00',
+	}
+
+	def _from_bytes(self, vals: dict, data: bytes) -> None:
+		pass # Just ignore it
+
+	def _to_bytes(self, vals: dict) -> bytes:
+		return self.p['filler'] * self.get_len(vals, b'')
+
+
+class Uint(Field):
+	''' An integer field: unsigned, N bits, big endian. '''
+
+	# Uint8 by default
+	DEF_LEN = 1
+
+	# Default parameters
+	DEF_PARAMS = {
+		'offset'	: 0,
+		'mult'		: 1,
+	}
+
+	# Big endian, unsigned
+	SIGN = False
+	BO = 'big'
+
+	def _from_bytes(self, vals: dict, data: bytes) -> None:
+		val = int.from_bytes(data, self.BO, signed=self.SIGN)
+		vals[self.name] = val * self.p['mult'] + self.p['offset']
+
+	def _to_bytes(self, vals: dict) -> bytes:
+		val = (self.get_val(vals) - self.p['offset']) // self.p['mult']
+		return val.to_bytes(self.len, self.BO, signed=self.SIGN)
+
+class Uint16BE(Uint):
+	DEF_LEN = 16 // 8
+
+class Uint16LE(Uint16BE):
+	BO = 'little'
+
+class Uint32BE(Uint):
+	DEF_LEN = 32 // 8
+
+class Uint32LE(Uint32BE):
+	BO = 'little'
+
+class Int(Uint):
+	SIGN = True
+
+class Int16BE(Int):
+	DEF_LEN = 16 // 8
+
+class Int16LE(Int16BE):
+	BO = 'little'
+
+class Int32BE(Int):
+	DEF_LEN = 32 // 8
+
+class Int32LE(Int32BE):
+	BO = 'little'
+
+
+class BitFieldSet(Field):
+	''' A set of bit-fields. '''
+
+	# Default parameters
+	DEF_PARAMS = {
+		# Default field order (MSB first)
+		'order'		: 'big',
+	}
+
+	# To be defined by derived types
+	STRUCT: Tuple['BitField', ...] = ()
+
+	def __init__(self, **kw) -> None:
+		Field.__init__(self, self.__class__.__name__, **kw)
+
+		self._fields = kw.get('set', self.STRUCT)
+		if type(self._fields) is not tuple:
+			raise ProtocolError('Expected a tuple')
+
+		# LSB first is basically reversed order
+		if self.p['order'] in ('little', 'lsb'):
+			self._fields = self._fields[::-1]
+
+		# Calculate the overall field length
+		if self.len == 0:
+			bl_sum = sum([f.bl for f in self._fields])
+			self.len = bl_sum // 8
+			if bl_sum % 8 > 0:
+				self.len += 1
+
+		# Re-define self.get_len() since we always know the length
+		self.get_len = lambda vals, data: self.len
+
+		# Pre-calculate offset and mask for each field
+		offset = self.len * 8
+		for f in self._fields:
+			if f.bl > offset:
+				raise ProtocolError(f, 'BitFieldSet overflow')
+			f.offset = offset - f.bl
+			f.mask = 2 ** f.bl - 1
+			offset -= f.bl
+
+	def _from_bytes(self, vals: dict, data: bytes) -> None:
+		blob = int.from_bytes(data, byteorder='big') # intentionally using 'big' here
+		for f in self._fields:
+			f.dec_val(vals, blob)
+
+	def _to_bytes(self, vals: dict) -> bytes:
+		blob = 0x00
+		for f in self._fields: # TODO: use functools.reduce()?
+			blob |= f.enc_val(vals)
+		return blob.to_bytes(self.len, byteorder='big')
+
+class BitField:
+	''' One field in a BitFieldSet. '''
+
+	# Special fields for BitFieldSet
+	offset: int = 0
+	mask: int = 0
+
+	class Spare:
+		''' Spare filling in a BitFieldSet. '''
+
+		def __init__(self, bl: int) -> None:
+			self.name = None
+			self.bl = bl
+
+		def enc_val(self, vals: dict) -> int:
+			return 0
+
+		def dec_val(self, vals: dict, blob: int) -> None:
+			pass # Just ignore it
+
+	def __init__(self, name: str, bl: int, **kw) -> None:
+		if bl < 1: # Ensure proper length
+			raise ProtocolError('Incorrect bit-field length')
+
+		self.name = name
+		self.bl = bl
+
+		# (Optional) fixed value for encoding and decoding
+		self.val: Optional[int] = kw.get('val', None)
+
+	def enc_val(self, vals: dict) -> int:
+		if self.val is None:
+			val = vals[self.name]
+		else:
+			val = self.val
+		return (val & self.mask) << self.offset
+
+	def dec_val(self, vals: dict, blob: int) -> None:
+		vals[self.name] = (blob >> self.offset) & self.mask
+		if (self.val is not None) and (vals[self.name] != self.val):
+			raise DecodeError('Unexpected value %d, expected %d'
+				% (vals[self.name], self.val))
+
+
+class Envelope:
+	''' A group of related fields. '''
+
+	STRUCT: Tuple[Codec, ...] = ()
+
+	def __init__(self, check_len: bool = True):
+		# TODO: ensure uniqueue field names in self.STRUCT
+		self.c: dict = { }
+		self.check_len = check_len
+
+	def __getitem__(self, key: str) -> Any:
+		return self.c[key]
+
+	def __setitem__(self, key: str, val: Any) -> None:
+		self.c[key] = val
+
+	def __delitem__(self, key: str) -> None:
+		del self.c[key]
+
+	def check(self, vals: dict) -> None:
+		''' Check the content before encoding and after decoding.
+		    Raise exceptions (e.g. ValueError) if something is wrong.
+
+		    Do not assert for every possible error (e.g. a negative value
+		    for a Uint field) if an exception will be thrown by the field's
+		    to_bytes() method anyway.  Only additional constraints here.
+		'''
+
+	def from_bytes(self, data: bytes) -> int:
+		self.c.clear() # forget the old content
+		return self._from_bytes(self.c, data)
+
+	def to_bytes(self) -> bytes:
+		return self._to_bytes(self.c)
+
+	def _from_bytes(self, vals: dict, data: bytes, offset: int = 0) -> int:
+		try: # Fields throw exceptions
+			for f in self.STRUCT:
+				offset += f.from_bytes(vals, data[offset:])
+		except Exception as e:
+			# Add contextual info
+			raise DecodeError(self, f, offset) from e
+		if self.check_len and len(data) != offset:
+			raise DecodeError(self, 'Unhandled tail octets: %s'
+						% data[offset:].hex())
+		self.check(vals) # Check the content after decoding (raises exceptions)
+		return offset
+
+	def _to_bytes(self, vals: dict) -> bytes:
+		def proc(f: Codec):
+			try: # Fields throw exceptions
+				return f.to_bytes(vals)
+			except Exception as e:
+				# Add contextual info
+				raise EncodeError(self, f) from e
+		self.check(vals) # Check the content before encoding (raises exceptions)
+		return b''.join([proc(f) for f in self.STRUCT])
+
+	class F(Field):
+		''' Field wrapper. '''
+
+		def __init__(self, e: 'Envelope', name: str, **kw) -> None:
+			Field.__init__(self, name, **kw)
+			self.e = e
+
+		def _from_bytes(self, vals: dict, data: bytes) -> None:
+			vals[self.name] = { }
+			self.e._from_bytes(vals[self.name], data)
+
+		def _to_bytes(self, vals: dict) -> bytes:
+			return self.e._to_bytes(self.get_val(vals))
+
+	def f(self, name: str, **kw) -> Field:
+		return self.F(self, name, **kw)
+
+
+class Sequence:
+	''' A sequence of repeating elements (e.g. TLVs). '''
+
+	# The item of sequence
+	ITEM: Optional[Envelope] = None
+
+	def __init__(self, **kw) -> None:
+		if (self.ITEM is None) and ('item' not in kw):
+			raise ProtocolError('Missing Sequence item')
+		self._item = kw.get('item', self.ITEM) # type: Envelope
+		self._item.check_len = False
+
+	def from_bytes(self, data: bytes) -> list:
+		proc = self._item._from_bytes
+		vseq, offset = [], 0
+		length = len(data)
+
+		while offset < length:
+			vseq.append({ }) # new item of sequence
+			offset += proc(vseq[-1], data[offset:])
+
+		return vseq
+
+	def to_bytes(self, vseq: list) -> bytes:
+		proc = self._item._to_bytes
+		return b''.join([proc(v) for v in vseq])
+
+	class F(Field):
+		''' Field wrapper. '''
+
+		def __init__(self, s: 'Sequence', name: str, **kw) -> None:
+			Field.__init__(self, name, **kw)
+			self.s = s
+
+		def _from_bytes(self, vals: dict, data: bytes) -> None:
+			vals[self.name] = self.s.from_bytes(data)
+
+		def _to_bytes(self, vals: dict) -> bytes:
+			return self.s.to_bytes(self.get_val(vals))
+
+	def f(self, name: str, **kw) -> Field:
+		return self.F(self, name, **kw)
diff --git a/src/target/trx_toolkit/test_codec.py b/src/target/trx_toolkit/test_codec.py
new file mode 100644
index 0000000..e0649d8
--- /dev/null
+++ b/src/target/trx_toolkit/test_codec.py
@@ -0,0 +1,592 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+'''
+Unit tests for declarative message codec.
+'''
+
+# (C) 2021 by sysmocom - s.f.m.c. GmbH <info at sysmocom.de>
+# Author: Vadim Yanitskiy <vyanitskiy at sysmocom.de>
+#
+# All Rights Reserved
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation; either version 2 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License along
+# with this program; if not, write to the Free Software Foundation, Inc.,
+# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+
+import unittest
+import struct
+import codec
+
+class TestField(codec.Field):
+	DEF_PARAMS = { 'key' : 0xde }
+	DEF_LEN = 4
+
+	@staticmethod
+	def xor(data: bytes, key: int = 0x00):
+		return bytes([x ^ key for x in data])
+
+	def _from_bytes(self, vals: dict, data: bytes) -> None:
+		vals[self.name] = self.xor(data, self.p['key'])
+
+	def _to_bytes(self, vals: dict) -> bytes:
+		return self.xor(self.get_val(vals), self.p['key'])
+
+class Field(unittest.TestCase):
+	MAGIC = b'\xde\xad\xbe\xef'
+
+	def test_to_bytes(self):
+		vals = { 'magic' : self.MAGIC, 'other' : 'unrelated' }
+		encoded_de = TestField.xor(self.MAGIC, 0xde)
+		encoded_88 = TestField.xor(self.MAGIC, 0x88)
+
+		with self.subTest('default length=4, default key=0xde'):
+			field = TestField('magic')
+			self.assertEqual(field.to_bytes(vals), encoded_de)
+
+		with self.subTest('different length=2, default key=0xde'):
+			field = TestField('magic', len=2)
+			vals['magic'] = vals['magic'][:2]
+			self.assertEqual(field.to_bytes(vals), encoded_de[:2])
+
+		with self.subTest('EncodeError due to length mismatch'):
+			field = TestField('magic', len=8)
+			with self.assertRaises(codec.EncodeError):
+				field.to_bytes(vals)
+
+	def test_from_bytes(self):
+		encoded_de = TestField.xor(self.MAGIC, 0xde) + b'\xff' * 60
+		encoded_88 = TestField.xor(self.MAGIC, 0x88) + b'\xff' * 60
+		vals = { 'magic' : 'overrien', 'other' : 'unchanged' }
+
+		with self.subTest('default length=4, default key=0xde'):
+			field = TestField('magic')
+			offset = field.from_bytes(vals, encoded_de)
+			self.assertEqual(vals['other'], 'unchanged')
+			self.assertEqual(vals['magic'], self.MAGIC)
+			self.assertEqual(offset, len(self.MAGIC))
+
+		with self.subTest('default length=4, different key=0x88'):
+			field = TestField('magic', key=0x88)
+			offset = field.from_bytes(vals, encoded_88)
+			self.assertEqual(vals['other'], 'unchanged')
+			self.assertEqual(vals['magic'], self.MAGIC)
+			self.assertEqual(offset, len(self.MAGIC))
+
+		with self.subTest('different length=2, default key=0xde'):
+			field = TestField('magic', len=2)
+			offset = field.from_bytes(vals, encoded_de)
+			self.assertEqual(vals['other'], 'unchanged')
+			self.assertEqual(vals['magic'], self.MAGIC[:2])
+			self.assertEqual(offset, 2)
+
+		with self.subTest('full length, different key=0x88'):
+			field = TestField('magic', len=0, key=0x88)
+			offset = field.from_bytes(vals, encoded_88)
+			self.assertEqual(vals['other'], 'unchanged')
+			self.assertEqual(vals['magic'], self.MAGIC + b'\x77' * 60)
+			self.assertEqual(offset, len(encoded_88))
+
+		with self.subTest('DecodeError due to short read'):
+			field = TestField('magic', len=4)
+			with self.assertRaises(codec.DecodeError):
+				field.from_bytes(vals, b'\x00')
+
+	def test_get_pres(self):
+		vals = { 'magic' : self.MAGIC }
+
+		with self.subTest('to_bytes() for a non-existing field'):
+			field = TestField('not-there')
+			with self.assertRaises(KeyError):
+				field.to_bytes(vals)
+
+		with self.subTest('to_bytes() for a field with get_pres()'):
+			field = TestField('magic', key=0x00)
+			field.get_pres = lambda v: not v['omit']
+
+			data = field.to_bytes({ **vals, 'omit' : False })
+			self.assertEqual(data, self.MAGIC)
+
+			data = field.to_bytes({ **vals, 'omit' : True })
+			self.assertEqual(data, b'')
+
+		with self.subTest('from_bytes() for a field with get_pres()'):
+			field = TestField('magic', key=0x00)
+			field.get_pres = lambda v: not v['omit']
+
+			vals = { 'omit' : False }
+			offset = field.from_bytes(vals, self.MAGIC)
+			self.assertEqual(vals['magic'], self.MAGIC)
+			self.assertEqual(offset, len(self.MAGIC))
+
+			vals = { 'omit' : True }
+			offset = field.from_bytes(vals, self.MAGIC)
+			self.assertFalse('magic' in vals)
+			self.assertEqual(offset, 0)
+
+	def test_get_len(self):
+		vals = { 'len' : 32, 'unrelated' : 'foo' }
+
+		field = TestField('magic', key=0x00)
+		field.get_len = lambda v, _: v['len']
+
+		with self.subTest('not enough octets in the buffer: 16 < 32'):
+			with self.assertRaises(codec.DecodeError):
+				field.from_bytes(vals, b'\xff' * 16)
+
+		with self.subTest('more than enough octets in the buffer'):
+			offset = field.from_bytes(vals, b'\xff' * 64)
+			self.assertEqual(vals['magic'], b'\xff' * 32)
+			self.assertEqual(offset, 32)
+
+		with self.subTest('length field does not exist'):
+			with self.assertRaises(KeyError):
+				field.from_bytes({ }, b'\xff' * 64)
+
+	def test_get_val(self):
+		field = TestField('magic', key=0x00, len=0)
+		field.get_val = lambda v: v.get('val', self.MAGIC)
+
+		with self.subTest('value is present in the dict'):
+			data = field.to_bytes({ 'val' : b'\xd0\xde' })
+			self.assertEqual(data, b'\xd0\xde')
+
+		with self.subTest('value is not present in the dict'):
+			data = field.to_bytes({ })
+			self.assertEqual(data, self.MAGIC)
+
+class Buf(unittest.TestCase):
+	MAGIC = b'\xde\xad' * 4
+
+	def test_to_bytes(self):
+		vals = { 'buf' : self.MAGIC }
+
+		with self.subTest('with no length constraints'):
+			field = codec.Buf('buf') # default: len=0
+			self.assertEqual(field.to_bytes(vals), self.MAGIC)
+
+		with self.subTest('with length constraints'):
+			field = codec.Buf('buf', len=len(self.MAGIC))
+			self.assertEqual(field.to_bytes(vals), self.MAGIC)
+
+		with self.subTest('EncodeError due to length mismatch'):
+			field = codec.Buf('buf', len=4)
+			with self.assertRaises(codec.EncodeError):
+				field.to_bytes(vals)
+
+	def test_from_bytes(self):
+		vals = { }
+
+		with self.subTest('with no length constraints'):
+			field = codec.Buf('buf') # default: len=0
+			offset = field.from_bytes(vals, self.MAGIC)
+			self.assertEqual(vals['buf'], self.MAGIC)
+			self.assertEqual(offset, len(self.MAGIC))
+
+		with self.subTest('with length constraints'):
+			field = codec.Buf('buf', len=2)
+			offset = field.from_bytes(vals, self.MAGIC)
+			self.assertEqual(vals['buf'], self.MAGIC[:2])
+			self.assertEqual(offset, len(self.MAGIC[:2]))
+
+		with self.subTest('DecodeError due to not enough bytes'):
+			field = codec.Buf('buf', len=64)
+			with self.assertRaises(codec.DecodeError):
+				field.from_bytes(vals, self.MAGIC)
+
+class Spare(unittest.TestCase):
+	# Fixed length with custom filler
+	SAA = codec.Spare('pad', len=4, filler=b'\xaa')
+	# Auto-calculated length with custom filler
+	SFF = codec.Spare('pad', filler=b'\xff')
+	SFF.get_len = lambda v, _: v['len']
+	# Fixed length with default filler
+	S00 = codec.Spare('pad', len=2)
+
+	def test_to_bytes(self):
+		self.assertEqual(self.SFF.to_bytes({ 'len' : 8 }), b'\xff' * 8)
+		self.assertEqual(self.SAA.to_bytes({ }), b'\xaa' * 4)
+		self.assertEqual(self.S00.to_bytes({ }), b'\x00' * 2)
+
+	def test_from_bytes(self):
+		with self.assertRaises(codec.DecodeError):
+			self.S00.from_bytes({ }, b'\x00') # Short read
+		self.assertEqual(self.SFF.from_bytes({ 'len' : 8 }, b'\xff' * 8), 8)
+		self.assertEqual(self.SAA.from_bytes({ }, b'\xaa' * 64), 4)
+		self.assertEqual(self.S00.from_bytes({ }, b'\x00' * 64), 2)
+
+class Uint(unittest.TestCase):
+	def _test_uint(self, field, fmt, vals):
+		for i in vals:
+			with self.subTest('to_bytes()'):
+				val = field.to_bytes({ field.name : i })
+				self.assertEqual(val, struct.pack(fmt, i))
+
+			with self.subTest('from_bytes()'):
+				data, parsed = struct.pack(fmt, i), { }
+				offset = field.from_bytes(parsed, data)
+				self.assertEqual(offset, len(data))
+				self.assertEqual(parsed[field.name], i)
+
+	def test_uint8(self):
+		self._test_uint(codec.Uint('foo'), 'B', range(2 ** 8))
+
+	def test_int8(self):
+		self._test_uint(codec.Int('foo'), 'b', range(-128, 128))
+
+	def test_uint16(self):
+		vals = (0, 65, 128, 255, 512, 1023, 2 ** 16 - 1)
+		self._test_uint(codec.Uint16BE('foo'), '>H', vals)
+		self._test_uint(codec.Uint16LE('foo'), '<H', vals)
+
+	def test_int16(self):
+		vals = (-32767, -16384, 0, 16384, 32767)
+		self._test_uint(codec.Int16BE('foo'), '>h', vals)
+		self._test_uint(codec.Int16LE('foo'), '<h', vals)
+
+	def test_uint32(self):
+		vals = (0, 33, 255, 1024, 1337, 4099, 2 ** 32 - 1)
+		self._test_uint(codec.Uint32BE('foo'), '>I', vals)
+		self._test_uint(codec.Uint32LE('foo'), '<I', vals)
+
+	def test_int32(self):
+		vals = (-2147483647, 0, 2147483647)
+		self._test_uint(codec.Int32BE('foo'), '>i', vals)
+		self._test_uint(codec.Int32LE('foo'), '<i', vals)
+
+	def test_offset_mult(self):
+		with self.subTest('encode / decode with offset=5'):
+			field = codec.Uint('foo', offset=5)
+
+			self.assertEqual(field.to_bytes({ 'foo' : 10 }), b'\x05')
+			self.assertEqual(field.to_bytes({ 'foo' :  5 }), b'\x00')
+
+			vals = { 'foo' : 'overriden' }
+			field.from_bytes(vals, b'\xff')
+			self.assertEqual(vals['foo'], 260)
+			field.from_bytes(vals, b'\x00')
+			self.assertEqual(vals['foo'], 5)
+
+		with self.subTest('encode / decode with mult=2'):
+			field = codec.Uint('foo', mult=2)
+
+			self.assertEqual(field.to_bytes({ 'foo' : 0 }), b'\x00')
+			self.assertEqual(field.to_bytes({ 'foo' : 3 }), b'\x01')
+			self.assertEqual(field.to_bytes({ 'foo' : 32 }), b'\x10')
+			self.assertEqual(field.to_bytes({ 'foo' : 64 }), b'\x20')
+
+			vals = { 'foo' : 'overriden' }
+			field.from_bytes(vals, b'\x00')
+			self.assertEqual(vals['foo'], 0 * 2)
+			field.from_bytes(vals, b'\x0f')
+			self.assertEqual(vals['foo'], 15 * 2)
+			field.from_bytes(vals, b'\xff')
+			self.assertEqual(vals['foo'], 255 * 2)
+
+class BitFieldSet(unittest.TestCase):
+	S16 = codec.BitFieldSet(set=(
+		codec.BitField('f4a', bl=4),
+		codec.BitField('f8', bl=8),
+		codec.BitField('f4b', bl=4),
+	))
+
+	S8M = codec.BitFieldSet(order='msb', set=(
+		codec.BitField('f4', bl=4),
+		codec.BitField('f1', bl=1),
+		codec.BitField('f3', bl=3),
+	))
+
+	S8L = codec.BitFieldSet(order='lsb', set=(
+		codec.BitField('f4', bl=4),
+		codec.BitField('f1', bl=1),
+		codec.BitField('f3', bl=3),
+	))
+
+	S8V = codec.BitFieldSet(set=(
+		codec.BitField('f4', bl=4, val=2),
+		codec.BitField('f1', bl=1, val=0),
+		codec.BitField('f3', bl=3),
+	))
+
+	S8P = codec.BitFieldSet(set=(
+		codec.BitField.Spare(bl=4),
+		codec.BitField('f4', bl=4),
+	))
+
+	@staticmethod
+	def from_bytes(s: codec.BitFieldSet, data: bytes) -> dict:
+		vals = { }
+		s.from_bytes(vals, data)
+		return vals
+
+	def test_len_auto(self):
+		with self.subTest('1 + 2 = 3 bits => 1 octet (with padding)'):
+			s = codec.BitFieldSet(set=(
+				codec.BitField('f1', bl=1),
+				codec.BitField('f2', bl=2),
+			))
+			self.assertEqual(s.len, 1)
+
+		with self.subTest('4 + 2 + 2 = 8 bits => 1 octet'):
+			s = codec.BitFieldSet(set=(
+				codec.BitField('f4', bl=4),
+				codec.BitField('f2a', bl=2),
+				codec.BitField('f2b', bl=2),
+			))
+			self.assertEqual(s.len, 1)
+
+		with self.subTest('12 + 4 + 2 = 18 bits => 3 octets (with padding)'):
+			s = codec.BitFieldSet(set=(
+				codec.BitField('f12', bl=12),
+				codec.BitField('f4', bl=4),
+				codec.BitField('f2', bl=2),
+			))
+			self.assertEqual(s.len, 3)
+
+	def test_overflow(self):
+		with self.assertRaises(codec.ProtocolError):
+			s = codec.BitFieldSet(len=1, set=(
+				codec.BitField('f6', bl=6),
+				codec.BitField('f4', bl=4),
+			))
+
+	def test_offset_mask(self):
+		calc = lambda s: [(f.name, f.offset, f.mask) for f in s._fields]
+
+		with self.subTest('16 bit total (MSB): f4a + f8 + f4b'):
+			om = [('f4a', 8 + 4, 0x0f), ('f8', 4, 0xff), ('f4b', 0, 0x0f)]
+			self.assertEqual(len(self.S16._fields), 3)
+			self.assertEqual(calc(self.S16), om)
+
+		with self.subTest('8 bit total (MSB): f4 + f1 + f3'):
+			om = [('f4', 1 + 3, 0x0f), ('f1', 3, 0x01), ('f3', 0, 0x07)]
+			self.assertEqual(len(self.S8M._fields), 3)
+			self.assertEqual(calc(self.S8M), om)
+
+		with self.subTest('8 bit total (LSB): f4 + f1 + f3'):
+			om = [('f3', 1 + 4, 0x07), ('f1', 4, 0x01), ('f4', 0, 0x0f)]
+			self.assertEqual(len(self.S8L._fields), 3)
+			self.assertEqual(calc(self.S8L), om)
+
+		with self.subTest('8 bit total (LSB): s4 + f4'):
+			om = [(None, 4, 0x0f), ('f4', 0, 0x0f)]
+			self.assertEqual(len(self.S8P._fields), 2)
+			self.assertEqual(calc(self.S8P), om)
+
+	def test_to_bytes(self):
+		with self.subTest('16 bit total (MSB): f4a + f8 + f4b'):
+			vals = { 'f4a' : 0x0f, 'f8' : 0xff, 'f4b' : 0x0f }
+			self.assertEqual(self.S16.to_bytes(vals), b'\xff\xff')
+			vals = { 'f4a' : 0x00, 'f8' : 0x00, 'f4b' : 0x00 }
+			self.assertEqual(self.S16.to_bytes(vals), b'\x00\x00')
+			vals = { 'f4a' : 0x0f, 'f8' : 0x00, 'f4b' : 0x0f }
+			self.assertEqual(self.S16.to_bytes(vals), b'\xf0\x0f')
+			vals = { 'f4a' : 0x00, 'f8' : 0xff, 'f4b' : 0x00 }
+			self.assertEqual(self.S16.to_bytes(vals), b'\x0f\xf0')
+
+		with self.subTest('8 bit total (MSB): f4 + f1 + f3'):
+			vals = { 'f4' : 0x0f, 'f1' : 0x01, 'f3' : 0x07 }
+			self.assertEqual(self.S8M.to_bytes(vals), b'\xff')
+			vals = { 'f4' : 0x00, 'f1' : 0x00, 'f3' : 0x00 }
+			self.assertEqual(self.S8M.to_bytes(vals), b'\x00')
+			vals = { 'f4' : 0x0f, 'f1' : 0x00, 'f3' : 0x00 }
+			self.assertEqual(self.S8M.to_bytes(vals), b'\xf0')
+
+		with self.subTest('8 bit total (LSB): f4 + f1 + f3'):
+			vals = { 'f4' : 0x0f, 'f1' : 0x01, 'f3' : 0x07 }
+			self.assertEqual(self.S8L.to_bytes(vals), b'\xff')
+			vals = { 'f4' : 0x00, 'f1' : 0x00, 'f3' : 0x00 }
+			self.assertEqual(self.S8L.to_bytes(vals), b'\x00')
+			vals = { 'f4' : 0x0f, 'f1' : 0x00, 'f3' : 0x00 }
+			self.assertEqual(self.S8L.to_bytes(vals), b'\x0f')
+
+	def test_from_bytes(self):
+		pad = b'\xff' * 64
+
+		with self.subTest('16 bit total (MSB): f4a + f8 + f4b'):
+			vals = { 'f4a' : 0x0f, 'f8' : 0xff, 'f4b' : 0x0f }
+			self.assertEqual(self.from_bytes(self.S16, b'\xff\xff' + pad), vals)
+			vals = { 'f4a' : 0x00, 'f8' : 0x00, 'f4b' : 0x00 }
+			self.assertEqual(self.from_bytes(self.S16, b'\x00\x00' + pad), vals)
+			vals = { 'f4a' : 0x0f, 'f8' : 0x00, 'f4b' : 0x0f }
+			self.assertEqual(self.from_bytes(self.S16, b'\xf0\x0f' + pad), vals)
+			vals = { 'f4a' : 0x00, 'f8' : 0xff, 'f4b' : 0x00 }
+			self.assertEqual(self.from_bytes(self.S16, b'\x0f\xf0' + pad), vals)
+
+		with self.subTest('8 bit total (MSB): f4 + f1 + f3'):
+			vals = { 'f4' : 0x0f, 'f1' : 0x01, 'f3' : 0x07 }
+			self.assertEqual(self.from_bytes(self.S8M, b'\xff' + pad), vals)
+			vals = { 'f4' : 0x00, 'f1' : 0x00, 'f3' : 0x00 }
+			self.assertEqual(self.from_bytes(self.S8M, b'\x00' + pad), vals)
+			vals = { 'f4' : 0x0f, 'f1' : 0x00, 'f3' : 0x00 }
+			self.assertEqual(self.from_bytes(self.S8M, b'\xf0' + pad), vals)
+
+		with self.subTest('8 bit total (LSB): f4 + f1 + f3'):
+			vals = { 'f4' : 0x0f, 'f1' : 0x01, 'f3' : 0x07 }
+			self.assertEqual(self.from_bytes(self.S8L, b'\xff' + pad), vals)
+			vals = { 'f4' : 0x00, 'f1' : 0x00, 'f3' : 0x00 }
+			self.assertEqual(self.from_bytes(self.S8L, b'\x00' + pad), vals)
+			vals = { 'f4' : 0x0f, 'f1' : 0x00, 'f3' : 0x00 }
+			self.assertEqual(self.from_bytes(self.S8L, b'\x0f' + pad), vals)
+
+	def test_to_bytes_val(self):
+		with self.subTest('fixed values in absence of user-supplied values'):
+			vals = { 'f3' : 0x00 } # | { 'f4' : 2, 'f1' : 0 }
+			self.assertEqual(self.S8V.to_bytes(vals), b'\x20')
+
+		with self.subTest('fixed values take precedence'):
+			vals = { 'f4' : 1, 'f1' : 1, 'f3' : 0 }
+			self.assertEqual(self.S8V.to_bytes(vals), b'\x20')
+
+	def test_from_bytes_val(self):
+		with self.assertRaises(codec.DecodeError):
+			self.S8V.from_bytes({ }, b'\xf0') # 'f4': 15 vs 2
+
+		with self.assertRaises(codec.DecodeError):
+			self.S8V.from_bytes({ }, b'\x08') # 'f1': 1 vs 0
+
+		# Field 'f3' takes any value, no exceptions shall be raised
+		for i in range(8):
+			data, vals = bytes([0x20 + i]), { 'f4' : 2, 'f1' : 0, 'f3' : i }
+			self.assertEqual(self.from_bytes(self.S8V, data), vals)
+
+	def test_to_bytes_spare(self):
+		self.assertEqual(self.S8P.to_bytes({ 'f4' : 0x00 }), b'\x00')
+		self.assertEqual(self.S8P.to_bytes({ 'f4' : 0x0f }), b'\x0f')
+		self.assertEqual(self.S8P.to_bytes({ 'f4' : 0xff }), b'\x0f')
+
+	def test_from_bytes_spare(self):
+		self.assertEqual(self.from_bytes(self.S8P, b'\x00'), { 'f4' : 0x00 })
+		self.assertEqual(self.from_bytes(self.S8P, b'\x0f'), { 'f4' : 0x0f })
+		self.assertEqual(self.from_bytes(self.S8P, b'\xff'), { 'f4' : 0x0f })
+
+class TestPDU(codec.Envelope):
+	STRUCT = (
+		codec.BitFieldSet(len=2, set=(
+			codec.BitField('ver', bl=4),
+			codec.BitField('flag', bl=1),
+		)),
+		codec.Uint16BE('len'),
+		codec.Buf('data'),
+		codec.Buf('tail', len=2),
+	)
+
+	def __init__(self, *args, **kw):
+		codec.Envelope.__init__(self, *args, **kw)
+		self.STRUCT[-3].get_val = lambda v: len(v['data'])
+		self.STRUCT[-2].get_len = lambda v, _: v['len']
+		self.STRUCT[-1].get_pres = lambda v: bool(v['flag'])
+
+	def check(self, vals: dict) -> None:
+		if not vals['ver'] in (0, 1, 2):
+			raise ValueError('Unknown version %d' % vals['ver'])
+
+class Envelope(unittest.TestCase):
+	def test_rest_octets(self):
+		pdu = TestPDU(check_len=False)
+		pdu.from_bytes(b'\x00' * 64)
+
+		with self.assertRaises(codec.DecodeError):
+			pdu = TestPDU(check_len=True)
+			pdu.from_bytes(b'\x00' * 64) # 'len' : 0
+
+	def test_field_raises(self):
+		pdu = TestPDU()
+		with self.assertRaises(codec.EncodeError):
+			pdu.c = { 'ver' : 0, 'flag' : 1, 'data' : b'\xff' * 16 }
+			pdu.to_bytes() # KeyError: 'tail' not found
+
+	def test_to_bytes(self):
+		pdu = TestPDU()
+
+		# No content in the new instances
+		self.assertEqual(pdu.c, { })
+
+		pdu.c = { 'ver' : 0, 'flag' : 1, 'data' : b'', 'tail' : b'\xde\xbe' }
+		self.assertEqual(pdu.to_bytes(), b'\x08\x00\x00\x00' + b'\xde\xbe')
+
+		pdu.c = { 'ver' : 1, 'flag' : 0, 'data' : b'\xff' * 15 }
+		self.assertEqual(pdu.to_bytes(), b'\x10\x00\x00\x0f' + b'\xff' * 15)
+
+		pdu.c = { 'ver' : 2, 'flag' : 1, 'data' : b'\xf0', 'tail' : b'\xbe\xed' }
+		self.assertEqual(pdu.to_bytes(), b'\x28\x00\x00\x01\xf0\xbe\xed')
+
+	def test_from_bytes(self):
+		pdu = TestPDU()
+
+		# No content in the new instances
+		self.assertEqual(pdu.c, { })
+
+		c = { 'ver' : 0, 'flag' : 1, 'len' : 0, 'data' : b'', 'tail' : b'\xde\xbe' }
+		pdu.from_bytes(b'\x08\x00\x00\x00' + b'\xde\xbe')
+		self.assertEqual(pdu.c, c)
+
+		c = { 'ver' : 1, 'flag' : 0, 'len' : 15, 'data' : b'\xff' * 15 }
+		pdu.from_bytes(b'\x10\x00\x00\x0f' + b'\xff' * 15)
+		self.assertEqual(pdu.c, c)
+
+		c = { 'ver' : 2, 'flag' : 1, 'len' : 1, 'data' : b'\xf0', 'tail' : b'\xbe\xed' }
+		pdu.from_bytes(b'\x28\x00\x00\x01\xf0\xbe\xed')
+		self.assertEqual(pdu.c, c)
+
+	def test_to_bytes_check(self):
+		pdu = TestPDU()
+
+		pdu.c = { 'ver' : 8, 'flag' : 1, 'data' : b'', 'tail' : b'\xde\xbe' }
+		with self.assertRaises(ValueError):
+			pdu.to_bytes()
+
+	def test_from_bytes_check(self):
+		pdu = TestPDU()
+
+		with self.assertRaises(ValueError):
+			pdu.from_bytes(b'\xf0\x00\x00\x00')
+
+class Sequence(unittest.TestCase):
+	class TLV(codec.Envelope):
+		STRUCT = (
+			codec.Uint('T'),
+			codec.Uint('L'),
+			codec.Buf('V'),
+		)
+
+		def __init__(self, *args, **kw) -> None:
+			codec.Envelope.__init__(self, *args, **kw)
+			self.STRUCT[-2].get_val = lambda v: len(v['V'])
+			self.STRUCT[-1].get_len = lambda v, _: v['L']
+
+	# Sequence of TLVs
+	SEQ = codec.Sequence(item=TLV())
+
+	Vseq, Bseq = [
+		{ 'T' : 0xde, 'L' : 4, 'V' : b'\xde\xad\xbe\xef' },
+		{ 'T' : 0xbe, 'L' : 2, 'V' : b'\xbe\xef' },
+		{ 'T' : 0xbe, 'L' : 2, 'V' : b'\xef\xbe' },
+		{ 'T' : 0x00, 'L' : 0, 'V' : b'' },
+	], b''.join([
+		b'\xde\x04\xde\xad\xbe\xef',
+		b'\xbe\x02\xbe\xef',
+		b'\xbe\x02\xef\xbe',
+		b'\x00\x00',
+	])
+
+	def test_to_bytes(self):
+		res = self.SEQ.to_bytes(self.Vseq)
+		self.assertEqual(res, self.Bseq)
+
+	def test_from_bytes(self):
+		res = self.SEQ.from_bytes(self.Bseq)
+		self.assertEqual(res, self.Vseq)
+
+if __name__ == '__main__':
+	# Run all test cases when this file is executed directly
+	unittest.main()

-- 
To view, visit https://gerrit.osmocom.org/c/osmocom-bb/+/23135
To unsubscribe, or for help writing mail filters, visit https://gerrit.osmocom.org/settings

Gerrit-Project: osmocom-bb
Gerrit-Branch: master
Gerrit-Change-Id: I7ff46b278c59af3720ee7f3950ea5a8b2f1313e1
Gerrit-Change-Number: 23135
Gerrit-PatchSet: 4
Gerrit-Owner: fixeria <vyanitskiy at sysmocom.de>
Gerrit-Reviewer: Jenkins Builder
Gerrit-Reviewer: laforge <laforge at osmocom.org>
Gerrit-Reviewer: pespin <pespin at sysmocom.de>
Gerrit-Reviewer: tnt <tnt at 246tNt.com>
Gerrit-MessageType: merged
-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://lists.osmocom.org/pipermail/gerrit-log/attachments/20210412/d75a7c11/attachment.htm>


More information about the gerrit-log mailing list