Files
pig-house-controller/main/proto/client_pb.py
2025-10-08 15:58:51 +08:00

379 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
根据client.proto生成的解析代码
适用于ESP32 MicroPython环境
"""
import struct
# --- Protobuf基础类型辅助函数 ---
def encode_varint(value):
    """Encode an integer as a protobuf base-128 varint.

    Negative values are encoded as their 64-bit two's complement (always
    10 bytes), which is how protobuf serializes negative int32/int64 fields.

    Args:
        value (int): integer in the range [-2**63, 2**64).

    Returns:
        bytearray: the encoded varint.

    Raises:
        ValueError: if ``value`` does not fit in 64 bits.
    """
    if value < 0:
        if value < -(1 << 63):
            raise ValueError("varint out of range: %d" % value)
        # Map to the unsigned 64-bit two's-complement representation; the
        # loop below only terminates correctly for non-negative values.
        value += 1 << 64
    elif value >= 1 << 64:
        raise ValueError("varint out of range: %d" % value)
    buf = bytearray()
    while value >= 0x80:
        # Low 7 bits plus the continuation flag.
        buf.append((value & 0x7F) | 0x80)
        value >>= 7
    buf.append(value & 0x7F)
    return buf
def decode_varint(buf, pos=0):
    """Decode a protobuf base-128 varint from ``buf`` starting at ``pos``.

    Args:
        buf (bytes | bytearray): buffer holding the encoded data.
        pos (int): offset of the varint's first byte.

    Returns:
        tuple: ``(value, new_pos)``, where ``new_pos`` is the offset just past
        the varint. The value is unsigned; a negative int32/int64 comes back
        as its 64-bit two's-complement value.

    Raises:
        ValueError: if the buffer ends before the varint terminates, or the
            varint is longer than the 10-byte protobuf maximum.
    """
    result = 0
    shift = 0
    end = len(buf)
    while True:
        if pos >= end:
            # Fail loudly instead of returning a partial value.
            raise ValueError("truncated varint")
        byte = buf[pos]
        pos += 1
        result |= (byte & 0x7F) << shift
        if not byte & 0x80:
            return result, pos
        shift += 7
        if shift >= 70:
            raise ValueError("varint longer than 10 bytes")
def encode_string(value):
    """Serialize a str as a length-delimited protobuf value.

    The result is the UTF-8 byte length as a varint, followed by the UTF-8
    bytes. The field tag is not included.

    Args:
        value (str): text to encode.

    Returns:
        bytearray: length prefix followed by the UTF-8 payload.
    """
    utf8 = value.encode('utf-8')
    prefix = encode_varint(len(utf8))
    return prefix + utf8
def decode_string(buf, pos=0):
    """Decode a length-delimited UTF-8 string starting at ``pos``.

    Args:
        buf (bytes | bytearray | memoryview): buffer holding the encoded data.
        pos (int): offset of the varint length prefix (just after the tag).

    Returns:
        tuple: ``(text, new_pos)``, where ``new_pos`` is the offset just past
        the string bytes.

    Raises:
        ValueError: if the declared length runs past the end of ``buf``.
            Invalid UTF-8 raises UnicodeError (a ValueError subclass in
            CPython).
    """
    length, pos = decode_varint(buf, pos)
    end = pos + length
    if end > len(buf):
        # Slicing would silently return a shortened string.
        raise ValueError("truncated string field")
    # bytes() makes .decode() available for memoryview input as well.
    return bytes(buf[pos:end]).decode('utf-8'), end
# --- 消息编码/解码函数 ---
def encode_raw_485_command(bus_number, command_bytes):
    """Serialize a Raw485Command message.

    Field layout:
        1 (varint)           -> bus_number
        2 (length-delimited) -> command_bytes

    Args:
        bus_number (int): bus number.
        command_bytes (bytes): raw RS-485 command to send.

    Returns:
        bytearray: the encoded message.
    """
    pieces = (
        encode_varint(1 << 3),            # tag: field 1, wire type 0
        encode_varint(bus_number),
        encode_varint((2 << 3) | 2),      # tag: field 2, wire type 2
        encode_varint(len(command_bytes)),
        command_bytes,
    )
    out = bytearray()
    for piece in pieces:
        out.extend(piece)
    return out
def decode_raw_485_command(buf):
    """Decode a Raw485Command message.

    Unknown fields are skipped according to their wire type. So are known
    fields that arrive with an unexpected wire type, which keeps the parser
    aligned.

    Args:
        buf (bytes | bytearray): the encoded message.

    Returns:
        dict: may contain ``'bus_number'`` (int) and ``'command_bytes'`` (a
        slice of ``buf``). Fields absent from the input are omitted.

    Raises:
        ValueError: on truncated input or an unsupported wire type.
    """
    result = {}
    pos = 0
    end = len(buf)
    while pos < end:
        tag, pos = decode_varint(buf, pos)
        field_number = tag >> 3
        wire_type = tag & 0x07
        if field_number == 1 and wire_type == 0:  # bus_number
            value, pos = decode_varint(buf, pos)
            result['bus_number'] = value
        elif field_number == 2 and wire_type == 2:  # command_bytes
            length, pos = decode_varint(buf, pos)
            result['command_bytes'] = buf[pos:pos + length]
            pos += length
        elif wire_type == 0:  # skip: varint
            _, pos = decode_varint(buf, pos)
        elif wire_type == 1:  # skip: fixed64
            pos += 8
        elif wire_type == 2:  # skip: length-delimited
            length, pos = decode_varint(buf, pos)
            pos += length
        elif wire_type == 5:  # skip: fixed32
            pos += 4
        else:  # groups (3/4) and invalid types (6/7) cannot be skipped safely
            raise ValueError("unsupported wire type %d" % wire_type)
    if pos > end:
        # A length-delimited or fixed-width field claimed more bytes than remain.
        raise ValueError("truncated Raw485Command")
    return result
def encode_collect_task(command_msg):
    """Serialize a CollectTask message.

    The only field written is ``command`` (field 2, length-delimited), which
    holds an embedded Raw485Command.

    Args:
        command_msg (dict): Raw485Command dict with ``'bus_number'`` and
            ``'command_bytes'`` keys.

    Returns:
        bytearray: the encoded message.
    """
    bus = command_msg['bus_number']
    frame = command_msg['command_bytes']
    inner = encode_raw_485_command(bus, frame)

    out = bytearray()
    out += encode_varint((2 << 3) | 2)   # tag: field 2, wire type 2
    out += encode_varint(len(inner))
    out += inner
    return out
def decode_collect_task(buf):
    """Decode a CollectTask message.

    Unknown fields are skipped according to their wire type. So are known
    fields that arrive with an unexpected wire type, which keeps the parser
    aligned.

    Args:
        buf (bytes | bytearray): the encoded message.

    Returns:
        dict: ``{'command': <Raw485Command dict>}`` when field 2 is present,
        otherwise an empty dict.

    Raises:
        ValueError: on truncated input or an unsupported wire type.
    """
    result = {}
    pos = 0
    end = len(buf)
    while pos < end:
        tag, pos = decode_varint(buf, pos)
        field_number = tag >> 3
        wire_type = tag & 0x07
        if field_number == 2 and wire_type == 2:  # command (embedded message)
            length, pos = decode_varint(buf, pos)
            if pos + length > end:
                raise ValueError("truncated CollectTask")
            result['command'] = decode_raw_485_command(buf[pos:pos + length])
            pos += length
        elif wire_type == 0:  # skip: varint
            _, pos = decode_varint(buf, pos)
        elif wire_type == 1:  # skip: fixed64
            pos += 8
        elif wire_type == 2:  # skip: length-delimited
            length, pos = decode_varint(buf, pos)
            pos += length
        elif wire_type == 5:  # skip: fixed32
            pos += 4
        else:  # groups (3/4) and invalid types (6/7) cannot be skipped safely
            raise ValueError("unsupported wire type %d" % wire_type)
    if pos > end:
        raise ValueError("truncated CollectTask")
    return result
def encode_batch_collect_command(correlation_id, tasks):
    """Serialize a BatchCollectCommand message.

    Field layout:
        1 (length-delimited)          -> correlation_id (str)
        2 (length-delimited, repeated) -> one CollectTask per entry in ``tasks``

    Args:
        correlation_id (str): correlation ID.
        tasks (list): CollectTask dicts, each with a ``'command'`` key holding
            a Raw485Command dict.

    Returns:
        bytearray: the encoded message.
    """
    out = bytearray()
    out += encode_varint((1 << 3) | 2)   # tag: field 1, wire type 2
    out += encode_string(correlation_id)

    task_tag = encode_varint((2 << 3) | 2)   # tag: field 2, wire type 2
    for entry in tasks:
        body = encode_collect_task(entry['command'])
        out += task_tag
        out += encode_varint(len(body))
        out += body
    return out
def decode_batch_collect_command(buf):
    """Decode a BatchCollectCommand message.

    Unknown fields are skipped according to their wire type. So are known
    fields that arrive with an unexpected wire type, which keeps the parser
    aligned.

    Args:
        buf (bytes | bytearray): the encoded message.

    Returns:
        dict: always has ``'tasks'`` (a list of CollectTask dicts, in wire
        order). Has ``'correlation_id'`` (str) when field 1 is present.

    Raises:
        ValueError: on truncated input or an unsupported wire type.
    """
    result = {'tasks': []}
    pos = 0
    end = len(buf)
    while pos < end:
        tag, pos = decode_varint(buf, pos)
        field_number = tag >> 3
        wire_type = tag & 0x07
        if field_number == 1 and wire_type == 2:  # correlation_id
            value, pos = decode_string(buf, pos)
            result['correlation_id'] = value
        elif field_number == 2 and wire_type == 2:  # tasks (repeated message)
            length, pos = decode_varint(buf, pos)
            if pos + length > end:
                raise ValueError("truncated BatchCollectCommand")
            result['tasks'].append(decode_collect_task(buf[pos:pos + length]))
            pos += length
        elif wire_type == 0:  # skip: varint
            _, pos = decode_varint(buf, pos)
        elif wire_type == 1:  # skip: fixed64
            pos += 8
        elif wire_type == 2:  # skip: length-delimited
            length, pos = decode_varint(buf, pos)
            pos += length
        elif wire_type == 5:  # skip: fixed32
            pos += 4
        else:  # groups (3/4) and invalid types (6/7) cannot be skipped safely
            raise ValueError("unsupported wire type %d" % wire_type)
    if pos > end:
        raise ValueError("truncated BatchCollectCommand")
    return result
def encode_collect_result(correlation_id, values):
    """Serialize a CollectResult message.

    Field layout:
        1 (length-delimited)    -> correlation_id (str)
        2 (fixed32, repeated)   -> one little-endian float32 per value,
                                   each with its own tag (unpacked form)

    Args:
        correlation_id (str): correlation ID.
        values (list): collected values (floats).

    Returns:
        bytearray: the encoded message.
    """
    out = bytearray()
    out += encode_varint((1 << 3) | 2)   # tag: field 1, wire type 2
    out += encode_string(correlation_id)

    float_tag = encode_varint((2 << 3) | 5)   # tag: field 2, wire type 5
    for sample in values:
        out += float_tag
        out += struct.pack('<f', sample)
    return out
def decode_collect_result(buf):
    """Decode a CollectResult message.

    Field 2 (``values``) is accepted both unpacked (one fixed32 per tag) and
    packed (one length-delimited run of fixed32s). Protobuf parsers must
    accept both forms, and proto3 encoders emit the packed form for repeated
    float fields by default. Unknown fields, and known fields with an
    unexpected wire type, are skipped according to their wire type.

    Args:
        buf (bytes | bytearray): the encoded message.

    Returns:
        dict: always has ``'values'`` (a list of floats decoded from
        little-endian float32, in wire order). Has ``'correlation_id'`` (str)
        when field 1 is present.

    Raises:
        ValueError: on truncated or malformed input, or an unsupported wire
            type.
    """
    result = {'values': []}
    values = result['values']
    pos = 0
    end = len(buf)
    while pos < end:
        tag, pos = decode_varint(buf, pos)
        field_number = tag >> 3
        wire_type = tag & 0x07
        if field_number == 1 and wire_type == 2:  # correlation_id
            value, pos = decode_string(buf, pos)
            result['correlation_id'] = value
        elif field_number == 2 and wire_type == 5:  # values, unpacked
            if pos + 4 > end:
                raise ValueError("truncated CollectResult")
            values.append(struct.unpack('<f', buf[pos:pos + 4])[0])
            pos += 4
        elif field_number == 2 and wire_type == 2:  # values, packed
            length, pos = decode_varint(buf, pos)
            stop = pos + length
            if length % 4 or stop > end:
                raise ValueError("malformed packed CollectResult.values")
            while pos < stop:
                values.append(struct.unpack('<f', buf[pos:pos + 4])[0])
                pos += 4
        elif wire_type == 0:  # skip: varint
            _, pos = decode_varint(buf, pos)
        elif wire_type == 1:  # skip: fixed64
            pos += 8
        elif wire_type == 2:  # skip: length-delimited
            length, pos = decode_varint(buf, pos)
            pos += length
        elif wire_type == 5:  # skip: fixed32
            pos += 4
        else:  # groups (3/4) and invalid types (6/7) cannot be skipped safely
            raise ValueError("unsupported wire type %d" % wire_type)
    if pos > end:
        raise ValueError("truncated CollectResult")
    return result
def encode_instruction(payload_type, payload_data):
    """Serialize an Instruction message, which wraps exactly one payload (oneof).

    Payload type to field number mapping (all length-delimited):
        'raw_485_command'       -> 1
        'batch_collect_command' -> 2
        'collect_result'        -> 3

    Args:
        payload_type (str): name of the oneof member to write.
        payload_data (dict): message dict for that member.

    Returns:
        bytearray: the encoded message.

    Raises:
        ValueError: if ``payload_type`` is not one of the names above.
    """
    if payload_type == 'raw_485_command':
        field = 1
        body = encode_raw_485_command(payload_data['bus_number'], payload_data['command_bytes'])
    elif payload_type == 'batch_collect_command':
        field = 2
        body = encode_batch_collect_command(payload_data['correlation_id'], payload_data['tasks'])
    elif payload_type == 'collect_result':
        field = 3
        body = encode_collect_result(payload_data['correlation_id'], payload_data['values'])
    else:
        raise ValueError("未知的指令负载类型")

    out = bytearray()
    out += encode_varint((field << 3) | 2)   # tag: field N, wire type 2
    out += encode_varint(len(body))
    out += body
    return out
def decode_instruction(buf):
    """Decode an Instruction message (a oneof of three payload types).

    Field number to result key mapping (all length-delimited):
        1 -> 'raw_485_command'
        2 -> 'batch_collect_command'
        3 -> 'collect_result'

    If several oneof members appear on the wire, only the last one is kept,
    as protobuf specifies for oneof fields. Other fields are skipped
    according to their wire type.

    Args:
        buf (bytes | bytearray): the encoded message.

    Returns:
        dict: at most one of the keys above, mapped to the decoded
        sub-message dict. Empty if no payload is present.

    Raises:
        ValueError: on truncated input or an unsupported wire type.
    """
    payloads = {
        1: ('raw_485_command', decode_raw_485_command),
        2: ('batch_collect_command', decode_batch_collect_command),
        3: ('collect_result', decode_collect_result),
    }
    result = {}
    pos = 0
    end = len(buf)
    while pos < end:
        tag, pos = decode_varint(buf, pos)
        field_number = tag >> 3
        wire_type = tag & 0x07
        if wire_type == 2:  # every oneof member is length-delimited
            length, pos = decode_varint(buf, pos)
            if pos + length > end:
                raise ValueError("truncated Instruction")
            value_buf = buf[pos:pos + length]
            pos += length
            entry = payloads.get(field_number)
            if entry is not None:
                key, decoder = entry
                # A later oneof member replaces any earlier one.
                for other_key, _ in payloads.values():
                    result.pop(other_key, None)
                result[key] = decoder(value_buf)
        elif wire_type == 0:  # skip: varint
            _, pos = decode_varint(buf, pos)
        elif wire_type == 1:  # skip: fixed64
            pos += 8
        elif wire_type == 5:  # skip: fixed32
            pos += 4
        else:  # groups (3/4) and invalid types (6/7) cannot be skipped safely
            raise ValueError("unsupported wire type %d" % wire_type)
    if pos > end:
        raise ValueError("truncated Instruction")
    return result
# --- 单元测试与使用范例 ---
if __name__ == "__main__":
    # Round-trip self-tests: encode sample messages, decode them again and
    # compare with the input. Run this file directly to execute them.
    print("--- 测试 Raw485Command ---")
    raw_cmd_data = {'bus_number': 1, 'command_bytes': b'\x01\x03\x00\x00\x00\x02\xc4\x0b'}
    encoded_raw_cmd = encode_raw_485_command(raw_cmd_data['bus_number'], raw_cmd_data['command_bytes'])
    print(f"编码后 Raw485Command: {encoded_raw_cmd.hex()}")
    decoded_raw_cmd = decode_raw_485_command(encoded_raw_cmd)
    print(f"解码后 Raw485Command: {decoded_raw_cmd}")
    assert decoded_raw_cmd == raw_cmd_data

    print("\n--- 测试 CollectTask ---")
    collect_task_data = {'command': raw_cmd_data}
    encoded_collect_task = encode_collect_task(collect_task_data['command'])
    print(f"编码后 CollectTask: {encoded_collect_task.hex()}")
    decoded_collect_task = decode_collect_task(encoded_collect_task)
    print(f"解码后 CollectTask: {decoded_collect_task}")
    assert decoded_collect_task == collect_task_data

    print("\n--- 测试 BatchCollectCommand ---")
    batch_collect_data = {
        'correlation_id': 'abc-123',
        'tasks': [
            {'command': {'bus_number': 1, 'command_bytes': b'\x01\x03\x00\x00\x00\x02\xc4\x0b'}},
            {'command': {'bus_number': 2, 'command_bytes': b'\x02\x03\x00\x01\x00\x01\xd5\xfa'}},
        ],
    }
    encoded_batch_collect = encode_batch_collect_command(batch_collect_data['correlation_id'], batch_collect_data['tasks'])
    print(f"编码后 BatchCollectCommand: {encoded_batch_collect.hex()}")
    decoded_batch_collect = decode_batch_collect_command(encoded_batch_collect)
    print(f"解码后 BatchCollectCommand: {decoded_batch_collect}")
    assert decoded_batch_collect == batch_collect_data

    print("\n--- 测试 CollectResult ---")
    collect_result_data = {
        'correlation_id': 'res-456',
        'values': [12.34, 56.78, 90.12],
    }
    encoded_collect_result = encode_collect_result(collect_result_data['correlation_id'], collect_result_data['values'])
    print(f"编码后 CollectResult: {encoded_collect_result.hex()}")
    decoded_collect_result = decode_collect_result(encoded_collect_result)
    print(f"解码后 CollectResult: {decoded_collect_result}")
    # float32 cannot represent these values exactly, so compare with a
    # tolerance. Check lengths first because zip() stops at the shorter list.
    assert decoded_collect_result['correlation_id'] == collect_result_data['correlation_id']
    assert len(decoded_collect_result['values']) == len(collect_result_data['values'])
    for got, want in zip(decoded_collect_result['values'], collect_result_data['values']):
        assert abs(got - want) < 1e-5

    print("\n--- 测试 Instruction (内含Raw485Command) ---")
    instruction_raw_485 = encode_instruction('raw_485_command', raw_cmd_data)
    print(f"编码后 Instruction (Raw485Command): {instruction_raw_485.hex()}")
    decoded_instruction_raw_485 = decode_instruction(instruction_raw_485)
    print(f"解码后 Instruction (Raw485Command): {decoded_instruction_raw_485}")
    assert decoded_instruction_raw_485['raw_485_command'] == raw_cmd_data

    print("\n--- 测试 Instruction (内含BatchCollectCommand) ---")
    instruction_batch_collect = encode_instruction('batch_collect_command', batch_collect_data)
    print(f"编码后 Instruction (BatchCollectCommand): {instruction_batch_collect.hex()}")
    decoded_instruction_batch_collect = decode_instruction(instruction_batch_collect)
    print(f"解码后 Instruction (BatchCollectCommand): {decoded_instruction_batch_collect}")
    # Compare the whole payload (ID and every task), not just the task count.
    assert decoded_instruction_batch_collect['batch_collect_command'] == batch_collect_data

    print("\n--- 测试 Instruction (内含CollectResult) ---")
    instruction_collect_result = encode_instruction('collect_result', collect_result_data)
    print(f"编码后 Instruction (CollectResult): {instruction_collect_result.hex()}")
    decoded_instruction_collect_result = decode_instruction(instruction_collect_result)
    print(f"解码后 Instruction (CollectResult): {decoded_instruction_collect_result}")
    inner_result = decoded_instruction_collect_result['collect_result']
    assert inner_result['correlation_id'] == collect_result_data['correlation_id']
    assert len(inner_result['values']) == len(collect_result_data['values'])
    for got, want in zip(inner_result['values'], collect_result_data['values']):
        assert abs(got - want) < 1e-5

    print("\n所有测试均已通过!")