tweak
This commit is contained in:
@@ -90,6 +90,10 @@ def force_str(force: bool) -> str:
|
||||
class TypeInfo(ABC):
|
||||
"""Base class for all type information."""
|
||||
|
||||
# Maximum size for fixed-size types (in bytes)
|
||||
# -1 for variable length types
|
||||
maximum_size = -1
|
||||
|
||||
def __init__(self, field: descriptor.FieldDescriptorProto) -> None:
|
||||
self._field = field
|
||||
|
||||
@@ -280,6 +284,7 @@ class DoubleType(TypeInfo):
|
||||
decode_64bit = "value.as_double()"
|
||||
encode_func = "encode_double"
|
||||
wire_type = WireType.FIXED64 # Uses wire type 1 according to protobuf spec
|
||||
maximum_size = 8 # 8 bytes for double
|
||||
|
||||
def dump(self, name: str) -> str:
|
||||
o = f'sprintf(buffer, "%g", {name});\n'
|
||||
@@ -299,6 +304,7 @@ class FloatType(TypeInfo):
|
||||
decode_32bit = "value.as_float()"
|
||||
encode_func = "encode_float"
|
||||
wire_type = WireType.FIXED32 # Uses wire type 5
|
||||
maximum_size = 4 # 4 bytes for float
|
||||
|
||||
def dump(self, name: str) -> str:
|
||||
o = f'sprintf(buffer, "%g", {name});\n'
|
||||
@@ -318,6 +324,7 @@ class Int64Type(TypeInfo):
|
||||
decode_varint = "value.as_int64()"
|
||||
encode_func = "encode_int64"
|
||||
wire_type = WireType.VARINT # Uses wire type 0
|
||||
maximum_size = 10 # Maximum 10 bytes for varint64
|
||||
|
||||
def dump(self, name: str) -> str:
|
||||
o = f'sprintf(buffer, "%lld", {name});\n'
|
||||
@@ -337,6 +344,7 @@ class UInt64Type(TypeInfo):
|
||||
decode_varint = "value.as_uint64()"
|
||||
encode_func = "encode_uint64"
|
||||
wire_type = WireType.VARINT # Uses wire type 0
|
||||
maximum_size = 10 # Maximum 10 bytes for varint64
|
||||
|
||||
def dump(self, name: str) -> str:
|
||||
o = f'sprintf(buffer, "%llu", {name});\n'
|
||||
@@ -356,6 +364,7 @@ class Int32Type(TypeInfo):
|
||||
decode_varint = "value.as_int32()"
|
||||
encode_func = "encode_int32"
|
||||
wire_type = WireType.VARINT # Uses wire type 0
|
||||
maximum_size = 5 # Maximum 5 bytes for varint32
|
||||
|
||||
def dump(self, name: str) -> str:
|
||||
o = f'sprintf(buffer, "%" PRId32, {name});\n'
|
||||
@@ -375,6 +384,7 @@ class Fixed64Type(TypeInfo):
|
||||
decode_64bit = "value.as_fixed64()"
|
||||
encode_func = "encode_fixed64"
|
||||
wire_type = WireType.FIXED64 # Uses wire type 1
|
||||
maximum_size = 8 # 8 bytes for fixed64
|
||||
|
||||
def dump(self, name: str) -> str:
|
||||
o = f'sprintf(buffer, "%llu", {name});\n'
|
||||
@@ -394,6 +404,7 @@ class Fixed32Type(TypeInfo):
|
||||
decode_32bit = "value.as_fixed32()"
|
||||
encode_func = "encode_fixed32"
|
||||
wire_type = WireType.FIXED32 # Uses wire type 5
|
||||
maximum_size = 4 # 4 bytes for fixed32
|
||||
|
||||
def dump(self, name: str) -> str:
|
||||
o = f'sprintf(buffer, "%" PRIu32, {name});\n'
|
||||
@@ -413,6 +424,7 @@ class BoolType(TypeInfo):
|
||||
decode_varint = "value.as_bool()"
|
||||
encode_func = "encode_bool"
|
||||
wire_type = WireType.VARINT # Uses wire type 0
|
||||
maximum_size = 1 # 1 byte for bool
|
||||
|
||||
def dump(self, name: str) -> str:
|
||||
o = f"out.append(YESNO({name}));"
|
||||
@@ -506,6 +518,7 @@ class UInt32Type(TypeInfo):
|
||||
decode_varint = "value.as_uint32()"
|
||||
encode_func = "encode_uint32"
|
||||
wire_type = WireType.VARINT # Uses wire type 0
|
||||
maximum_size = 5 # Maximum 5 bytes for varint32
|
||||
|
||||
def dump(self, name: str) -> str:
|
||||
o = f'sprintf(buffer, "%" PRIu32, {name});\n'
|
||||
@@ -520,6 +533,8 @@ class UInt32Type(TypeInfo):
|
||||
|
||||
@register_type(14)
|
||||
class EnumType(TypeInfo):
|
||||
maximum_size = 5 # Maximum 5 bytes for enum (same as uint32)
|
||||
|
||||
@property
|
||||
def cpp_type(self) -> str:
|
||||
return f"enums::{self._field.type_name[1:]}"
|
||||
@@ -552,6 +567,7 @@ class SFixed32Type(TypeInfo):
|
||||
decode_32bit = "value.as_sfixed32()"
|
||||
encode_func = "encode_sfixed32"
|
||||
wire_type = WireType.FIXED32 # Uses wire type 5
|
||||
maximum_size = 4 # 4 bytes for sfixed32
|
||||
|
||||
def dump(self, name: str) -> str:
|
||||
o = f'sprintf(buffer, "%" PRId32, {name});\n'
|
||||
@@ -571,6 +587,7 @@ class SFixed64Type(TypeInfo):
|
||||
decode_64bit = "value.as_sfixed64()"
|
||||
encode_func = "encode_sfixed64"
|
||||
wire_type = WireType.FIXED64 # Uses wire type 1
|
||||
maximum_size = 8 # 8 bytes for sfixed64
|
||||
|
||||
def dump(self, name: str) -> str:
|
||||
o = f'sprintf(buffer, "%lld", {name});\n'
|
||||
@@ -590,6 +607,7 @@ class SInt32Type(TypeInfo):
|
||||
decode_varint = "value.as_sint32()"
|
||||
encode_func = "encode_sint32"
|
||||
wire_type = WireType.VARINT # Uses wire type 0
|
||||
maximum_size = 5 # Maximum 5 bytes for sint32
|
||||
|
||||
def dump(self, name: str) -> str:
|
||||
o = f'sprintf(buffer, "%" PRId32, {name});\n'
|
||||
@@ -609,6 +627,7 @@ class SInt64Type(TypeInfo):
|
||||
decode_varint = "value.as_sint64()"
|
||||
encode_func = "encode_sint64"
|
||||
wire_type = WireType.VARINT # Uses wire type 0
|
||||
maximum_size = 10 # Maximum 10 bytes for sint64
|
||||
|
||||
def dump(self, name: str) -> str:
|
||||
o = f'sprintf(buffer, "%lld", {name});\n'
|
||||
@@ -622,6 +641,8 @@ class SInt64Type(TypeInfo):
|
||||
|
||||
|
||||
class RepeatedTypeInfo(TypeInfo):
|
||||
maximum_size = -1 # Repeated fields are variable length
|
||||
|
||||
def __init__(self, field: descriptor.FieldDescriptorProto) -> None:
|
||||
super().__init__(field)
|
||||
self._ti: TypeInfo = TYPE_INFO[field.type](field)
|
||||
@@ -762,30 +783,25 @@ def build_enum_type(desc) -> tuple[str, str]:
|
||||
return out, cpp
|
||||
|
||||
|
||||
# Maximum sizes for fixed-size protobuf types
|
||||
PROTOBUF_TYPE_SIZES = {
|
||||
1: 8, # DOUBLE
|
||||
2: 4, # FLOAT
|
||||
3: 10, # INT64 (max varint size)
|
||||
4: 10, # UINT64 (max varint size)
|
||||
5: 5, # INT32 (max varint size)
|
||||
6: 8, # FIXED64
|
||||
7: 4, # FIXED32
|
||||
8: 1, # BOOL
|
||||
# 9: STRING (variable)
|
||||
# 10: GROUP (deprecated)
|
||||
# 11: MESSAGE (variable)
|
||||
# 12: BYTES (variable)
|
||||
13: 5, # UINT32 (max varint size)
|
||||
14: 5, # ENUM (same as uint32)
|
||||
15: 4, # SFIXED32
|
||||
16: 8, # SFIXED64
|
||||
17: 5, # SINT32 (max varint size)
|
||||
18: 10, # SINT64 (max varint size)
|
||||
}
|
||||
def get_type_sizes():
|
||||
"""Dynamically build type size dictionaries from registered types."""
|
||||
# Build PROTOBUF_TYPE_SIZES from TYPE_INFO
|
||||
type_sizes = {}
|
||||
variable_length_types = set()
|
||||
|
||||
# Variable length types that cannot have fixed size
|
||||
VARIABLE_LENGTH_TYPES = {9, 11, 12} # STRING, MESSAGE, BYTES
|
||||
for type_id, type_class in TYPE_INFO.items():
|
||||
if hasattr(type_class, "maximum_size"):
|
||||
if type_class.maximum_size == -1:
|
||||
variable_length_types.add(type_id)
|
||||
else:
|
||||
type_sizes[type_id] = type_class.maximum_size
|
||||
|
||||
return type_sizes, variable_length_types
|
||||
|
||||
|
||||
# Build the type size mappings once when the module loads
|
||||
# This avoids recalculating these dictionaries for every message
|
||||
PROTOBUF_TYPE_SIZES, VARIABLE_LENGTH_TYPES = get_type_sizes()
|
||||
|
||||
|
||||
def calculate_fixed_message_size(desc: descriptor.DescriptorProto) -> int:
|
||||
|
||||
Reference in New Issue
Block a user