This commit is contained in:
J. Nick Koston
2025-05-16 14:06:54 -04:00
parent 8f152c5a0a
commit 3107e2f85e

View File

@@ -90,6 +90,10 @@ def force_str(force: bool) -> str:
class TypeInfo(ABC):
"""Base class for all type information."""
# Maximum size for fixed-size types (in bytes)
# -1 for variable length types
maximum_size = -1
def __init__(self, field: descriptor.FieldDescriptorProto) -> None:
self._field = field
@@ -280,6 +284,7 @@ class DoubleType(TypeInfo):
decode_64bit = "value.as_double()"
encode_func = "encode_double"
wire_type = WireType.FIXED64 # Uses wire type 1 according to protobuf spec
maximum_size = 8 # 8 bytes for double
def dump(self, name: str) -> str:
o = f'sprintf(buffer, "%g", {name});\n'
@@ -299,6 +304,7 @@ class FloatType(TypeInfo):
decode_32bit = "value.as_float()"
encode_func = "encode_float"
wire_type = WireType.FIXED32 # Uses wire type 5
maximum_size = 4 # 4 bytes for float
def dump(self, name: str) -> str:
o = f'sprintf(buffer, "%g", {name});\n'
@@ -318,6 +324,7 @@ class Int64Type(TypeInfo):
decode_varint = "value.as_int64()"
encode_func = "encode_int64"
wire_type = WireType.VARINT # Uses wire type 0
maximum_size = 10 # Maximum 10 bytes for varint64
def dump(self, name: str) -> str:
o = f'sprintf(buffer, "%lld", {name});\n'
@@ -337,6 +344,7 @@ class UInt64Type(TypeInfo):
decode_varint = "value.as_uint64()"
encode_func = "encode_uint64"
wire_type = WireType.VARINT # Uses wire type 0
maximum_size = 10 # Maximum 10 bytes for varint64
def dump(self, name: str) -> str:
o = f'sprintf(buffer, "%llu", {name});\n'
@@ -356,6 +364,7 @@ class Int32Type(TypeInfo):
decode_varint = "value.as_int32()"
encode_func = "encode_int32"
wire_type = WireType.VARINT # Uses wire type 0
maximum_size = 5 # Maximum 5 bytes for varint32
def dump(self, name: str) -> str:
o = f'sprintf(buffer, "%" PRId32, {name});\n'
@@ -375,6 +384,7 @@ class Fixed64Type(TypeInfo):
decode_64bit = "value.as_fixed64()"
encode_func = "encode_fixed64"
wire_type = WireType.FIXED64 # Uses wire type 1
maximum_size = 8 # 8 bytes for fixed64
def dump(self, name: str) -> str:
o = f'sprintf(buffer, "%llu", {name});\n'
@@ -394,6 +404,7 @@ class Fixed32Type(TypeInfo):
decode_32bit = "value.as_fixed32()"
encode_func = "encode_fixed32"
wire_type = WireType.FIXED32 # Uses wire type 5
maximum_size = 4 # 4 bytes for fixed32
def dump(self, name: str) -> str:
o = f'sprintf(buffer, "%" PRIu32, {name});\n'
@@ -413,6 +424,7 @@ class BoolType(TypeInfo):
decode_varint = "value.as_bool()"
encode_func = "encode_bool"
wire_type = WireType.VARINT # Uses wire type 0
maximum_size = 1 # 1 byte for bool
def dump(self, name: str) -> str:
o = f"out.append(YESNO({name}));"
@@ -506,6 +518,7 @@ class UInt32Type(TypeInfo):
decode_varint = "value.as_uint32()"
encode_func = "encode_uint32"
wire_type = WireType.VARINT # Uses wire type 0
maximum_size = 5 # Maximum 5 bytes for varint32
def dump(self, name: str) -> str:
o = f'sprintf(buffer, "%" PRIu32, {name});\n'
@@ -520,6 +533,8 @@ class UInt32Type(TypeInfo):
@register_type(14)
class EnumType(TypeInfo):
maximum_size = 5 # Maximum 5 bytes for enum (same as uint32)
@property
def cpp_type(self) -> str:
return f"enums::{self._field.type_name[1:]}"
@@ -552,6 +567,7 @@ class SFixed32Type(TypeInfo):
decode_32bit = "value.as_sfixed32()"
encode_func = "encode_sfixed32"
wire_type = WireType.FIXED32 # Uses wire type 5
maximum_size = 4 # 4 bytes for sfixed32
def dump(self, name: str) -> str:
o = f'sprintf(buffer, "%" PRId32, {name});\n'
@@ -571,6 +587,7 @@ class SFixed64Type(TypeInfo):
decode_64bit = "value.as_sfixed64()"
encode_func = "encode_sfixed64"
wire_type = WireType.FIXED64 # Uses wire type 1
maximum_size = 8 # 8 bytes for sfixed64
def dump(self, name: str) -> str:
o = f'sprintf(buffer, "%lld", {name});\n'
@@ -590,6 +607,7 @@ class SInt32Type(TypeInfo):
decode_varint = "value.as_sint32()"
encode_func = "encode_sint32"
wire_type = WireType.VARINT # Uses wire type 0
maximum_size = 5 # Maximum 5 bytes for sint32
def dump(self, name: str) -> str:
o = f'sprintf(buffer, "%" PRId32, {name});\n'
@@ -609,6 +627,7 @@ class SInt64Type(TypeInfo):
decode_varint = "value.as_sint64()"
encode_func = "encode_sint64"
wire_type = WireType.VARINT # Uses wire type 0
maximum_size = 10 # Maximum 10 bytes for sint64
def dump(self, name: str) -> str:
o = f'sprintf(buffer, "%lld", {name});\n'
@@ -622,6 +641,8 @@ class SInt64Type(TypeInfo):
class RepeatedTypeInfo(TypeInfo):
maximum_size = -1 # Repeated fields are variable length
def __init__(self, field: descriptor.FieldDescriptorProto) -> None:
super().__init__(field)
self._ti: TypeInfo = TYPE_INFO[field.type](field)
@@ -762,30 +783,25 @@ def build_enum_type(desc) -> tuple[str, str]:
return out, cpp
# Maximum sizes for fixed-size protobuf types
PROTOBUF_TYPE_SIZES = {
1: 8, # DOUBLE
2: 4, # FLOAT
3: 10, # INT64 (max varint size)
4: 10, # UINT64 (max varint size)
5: 5, # INT32 (max varint size)
6: 8, # FIXED64
7: 4, # FIXED32
8: 1, # BOOL
# 9: STRING (variable)
# 10: GROUP (deprecated)
# 11: MESSAGE (variable)
# 12: BYTES (variable)
13: 5, # UINT32 (max varint size)
14: 5, # ENUM (same as uint32)
15: 4, # SFIXED32
16: 8, # SFIXED64
17: 5, # SINT32 (max varint size)
18: 10, # SINT64 (max varint size)
}
def get_type_sizes():
"""Dynamically build type size dictionaries from registered types."""
# Build PROTOBUF_TYPE_SIZES from TYPE_INFO
type_sizes = {}
variable_length_types = set()
# Variable length types that cannot have fixed size
VARIABLE_LENGTH_TYPES = {9, 11, 12} # STRING, MESSAGE, BYTES
for type_id, type_class in TYPE_INFO.items():
if hasattr(type_class, "maximum_size"):
if type_class.maximum_size == -1:
variable_length_types.add(type_id)
else:
type_sizes[type_id] = type_class.maximum_size
return type_sizes, variable_length_types
# Build the type size mappings once when the module loads
# This avoids recalculating these dictionaries for every message
PROTOBUF_TYPE_SIZES, VARIABLE_LENGTH_TYPES = get_type_sizes()
def calculate_fixed_message_size(desc: descriptor.DescriptorProto) -> int: