diff --git a/script/api_protobuf/api_protobuf.py b/script/api_protobuf/api_protobuf.py index 4295cedeec..fb584bb748 100755 --- a/script/api_protobuf/api_protobuf.py +++ b/script/api_protobuf/api_protobuf.py @@ -90,6 +90,10 @@ def force_str(force: bool) -> str: class TypeInfo(ABC): """Base class for all type information.""" + # Maximum size for fixed-size types (in bytes) + # -1 for variable length types + maximum_size = -1 + def __init__(self, field: descriptor.FieldDescriptorProto) -> None: self._field = field @@ -280,6 +284,7 @@ class DoubleType(TypeInfo): decode_64bit = "value.as_double()" encode_func = "encode_double" wire_type = WireType.FIXED64 # Uses wire type 1 according to protobuf spec + maximum_size = 8 # 8 bytes for double def dump(self, name: str) -> str: o = f'sprintf(buffer, "%g", {name});\n' @@ -299,6 +304,7 @@ class FloatType(TypeInfo): decode_32bit = "value.as_float()" encode_func = "encode_float" wire_type = WireType.FIXED32 # Uses wire type 5 + maximum_size = 4 # 4 bytes for float def dump(self, name: str) -> str: o = f'sprintf(buffer, "%g", {name});\n' @@ -318,6 +324,7 @@ class Int64Type(TypeInfo): decode_varint = "value.as_int64()" encode_func = "encode_int64" wire_type = WireType.VARINT # Uses wire type 0 + maximum_size = 10 # Maximum 10 bytes for varint64 def dump(self, name: str) -> str: o = f'sprintf(buffer, "%lld", {name});\n' @@ -337,6 +344,7 @@ class UInt64Type(TypeInfo): decode_varint = "value.as_uint64()" encode_func = "encode_uint64" wire_type = WireType.VARINT # Uses wire type 0 + maximum_size = 10 # Maximum 10 bytes for varint64 def dump(self, name: str) -> str: o = f'sprintf(buffer, "%llu", {name});\n' @@ -356,6 +364,7 @@ class Int32Type(TypeInfo): decode_varint = "value.as_int32()" encode_func = "encode_int32" wire_type = WireType.VARINT # Uses wire type 0 + maximum_size = 5 # Maximum 5 bytes for varint32 def dump(self, name: str) -> str: o = f'sprintf(buffer, "%" PRId32, {name});\n' @@ -375,6 +384,7 @@ class Fixed64Type(TypeInfo): decode_64bit = "value.as_fixed64()" encode_func = "encode_fixed64" wire_type = WireType.FIXED64 # Uses wire type 1 + maximum_size = 8 # 8 bytes for fixed64 def dump(self, name: str) -> str: o = f'sprintf(buffer, "%llu", {name});\n' @@ -394,6 +404,7 @@ class Fixed32Type(TypeInfo): decode_32bit = "value.as_fixed32()" encode_func = "encode_fixed32" wire_type = WireType.FIXED32 # Uses wire type 5 + maximum_size = 4 # 4 bytes for fixed32 def dump(self, name: str) -> str: o = f'sprintf(buffer, "%" PRIu32, {name});\n' @@ -413,6 +424,7 @@ class BoolType(TypeInfo): decode_varint = "value.as_bool()" encode_func = "encode_bool" wire_type = WireType.VARINT # Uses wire type 0 + maximum_size = 1 # 1 byte for bool def dump(self, name: str) -> str: o = f"out.append(YESNO({name}));" @@ -506,6 +518,7 @@ class UInt32Type(TypeInfo): decode_varint = "value.as_uint32()" encode_func = "encode_uint32" wire_type = WireType.VARINT # Uses wire type 0 + maximum_size = 5 # Maximum 5 bytes for varint32 def dump(self, name: str) -> str: o = f'sprintf(buffer, "%" PRIu32, {name});\n' @@ -520,6 +533,8 @@ class UInt32Type(TypeInfo): @register_type(14) class EnumType(TypeInfo): + maximum_size = 5 # Maximum 5 bytes for enum (same as uint32) + @property def cpp_type(self) -> str: return f"enums::{self._field.type_name[1:]}" @@ -552,6 +567,7 @@ class SFixed32Type(TypeInfo): decode_32bit = "value.as_sfixed32()" encode_func = "encode_sfixed32" wire_type = WireType.FIXED32 # Uses wire type 5 + maximum_size = 4 # 4 bytes for sfixed32 def dump(self, name: str) -> str: o = f'sprintf(buffer, "%" PRId32, {name});\n' @@ -571,6 +587,7 @@ class SFixed64Type(TypeInfo): decode_64bit = "value.as_sfixed64()" encode_func = "encode_sfixed64" wire_type = WireType.FIXED64 # Uses wire type 1 + maximum_size = 8 # 8 bytes for sfixed64 def dump(self, name: str) -> str: o = f'sprintf(buffer, "%lld", {name});\n' @@ -590,6 +607,7 @@ class SInt32Type(TypeInfo): decode_varint = "value.as_sint32()" encode_func = "encode_sint32" wire_type = WireType.VARINT # Uses wire type 0 + maximum_size = 5 # Maximum 5 bytes for sint32 def dump(self, name: str) -> str: o = f'sprintf(buffer, "%" PRId32, {name});\n' @@ -609,6 +627,7 @@ class SInt64Type(TypeInfo): decode_varint = "value.as_sint64()" encode_func = "encode_sint64" wire_type = WireType.VARINT # Uses wire type 0 + maximum_size = 10 # Maximum 10 bytes for sint64 def dump(self, name: str) -> str: o = f'sprintf(buffer, "%lld", {name});\n' @@ -622,6 +641,8 @@ class SInt64Type(TypeInfo): class RepeatedTypeInfo(TypeInfo): + maximum_size = -1 # Repeated fields are variable length + def __init__(self, field: descriptor.FieldDescriptorProto) -> None: super().__init__(field) self._ti: TypeInfo = TYPE_INFO[field.type](field) @@ -762,30 +783,25 @@ def build_enum_type(desc) -> tuple[str, str]: return out, cpp -# Maximum sizes for fixed-size protobuf types -PROTOBUF_TYPE_SIZES = { - 1: 8, # DOUBLE - 2: 4, # FLOAT - 3: 10, # INT64 (max varint size) - 4: 10, # UINT64 (max varint size) - 5: 5, # INT32 (max varint size) - 6: 8, # FIXED64 - 7: 4, # FIXED32 - 8: 1, # BOOL - # 9: STRING (variable) - # 10: GROUP (deprecated) - # 11: MESSAGE (variable) - # 12: BYTES (variable) - 13: 5, # UINT32 (max varint size) - 14: 5, # ENUM (same as uint32) - 15: 4, # SFIXED32 - 16: 8, # SFIXED64 - 17: 5, # SINT32 (max varint size) - 18: 10, # SINT64 (max varint size) -} +def get_type_sizes(): + """Dynamically build type size dictionaries from registered types.""" + # Build PROTOBUF_TYPE_SIZES from TYPE_INFO + type_sizes = {} + variable_length_types = set() -# Variable length types that cannot have fixed size -VARIABLE_LENGTH_TYPES = {9, 11, 12} # STRING, MESSAGE, BYTES + for type_id, type_class in TYPE_INFO.items(): + if hasattr(type_class, "maximum_size"): + if type_class.maximum_size == -1: + variable_length_types.add(type_id) + else: + type_sizes[type_id] = type_class.maximum_size + + return type_sizes, variable_length_types + + +# Build the type size mappings once when the module loads +# This avoids recalculating these dictionaries for every message +PROTOBUF_TYPE_SIZES, VARIABLE_LENGTH_TYPES = get_type_sizes() def calculate_fixed_message_size(desc: descriptor.DescriptorProto) -> int: