379 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			379 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| #!/usr/bin/env python
 | ||
| # -*- coding: utf-8 -*-
 | ||
| 
 | ||
| """
 | ||
| 根据client.proto生成的解析代码
 | ||
| 适用于ESP32 MicroPython环境
 | ||
| """
 | ||
| 
 | ||
| import struct
 | ||
| 
 | ||
| # --- Protobuf基础类型辅助函数 ---
 | ||
| 
 | ||
def encode_varint(value):
    """Encode an integer as a protobuf base-128 varint.

    Non-negative values use the minimal number of bytes. Negative values are
    encoded as their 64-bit two's-complement representation (always 10
    bytes), which is how protobuf writes negative int32/int64 fields.

    Args:
        value (int): integer to encode.

    Returns:
        bytearray: the varint bytes (a new buffer on every call).
    """
    if value < 0:
        # Without this mask the loop below never runs for a negative number and
        # a single bogus byte would be emitted.
        value &= 0xFFFFFFFFFFFFFFFF
    buf = bytearray()
    while value >= 0x80:
        buf.append((value & 0x7F) | 0x80)  # low 7 bits + continuation flag
        value >>= 7
    buf.append(value)  # final byte: < 0x80, continuation flag clear
    return buf
 | ||
| 
 | ||
def decode_varint(buf, pos=0):
    """Decode a protobuf base-128 varint starting at ``buf[pos]``.

    Args:
        buf (bytes | bytearray): buffer to read from.
        pos (int): index of the varint's first byte.

    Returns:
        tuple: ``(value, new_pos)``, where ``new_pos`` is the index just past
        the varint. ``value`` is unsigned (a negative int64 comes back as its
        64-bit two's-complement value).

    Raises:
        ValueError: if ``buf`` ends before the varint terminates, or the
            varint is longer than the 10 bytes a 64-bit value can need.
    """
    result = 0
    shift = 0
    end = len(buf)
    while pos < end:
        byte = buf[pos]
        pos += 1
        result |= (byte & 0x7F) << shift
        if not byte & 0x80:  # continuation flag clear: last byte
            return result, pos
        shift += 7
        if shift >= 70:  # 10 bytes consumed and still continuing
            raise ValueError("varint too long")
    # Falling out of the loop means the continuation flag was still set (or
    # there were no bytes at all), so the data is truncated.
    raise ValueError("truncated varint")
 | ||
| 
 | ||
def encode_string(value):
    """Encode ``value`` as a length-delimited UTF-8 payload.

    The result is the varint byte length followed by the UTF-8 bytes. The
    field tag is not included; callers write it first.

    Args:
        value (str): text to encode.

    Returns:
        bytearray: length prefix + UTF-8 bytes.
    """
    utf8 = value.encode('utf-8')
    return encode_varint(len(utf8)) + utf8
 | ||
| 
 | ||
def decode_string(buf, pos=0):
    """Decode a length-delimited UTF-8 string starting at ``buf[pos]``.

    ``buf[pos]`` must be the start of the varint length prefix (the field tag
    has already been consumed by the caller).

    Args:
        buf (bytes | bytearray): buffer to read from.
        pos (int): index of the length prefix.

    Returns:
        tuple: ``(text, new_pos)``, where ``new_pos`` is just past the string.

    Raises:
        ValueError: if the declared length runs past the end of ``buf``.
    """
    length, pos = decode_varint(buf, pos)
    end = pos + length
    if end > len(buf):
        # Slicing would silently return fewer bytes and the returned position
        # would point beyond the buffer, so reject truncated input explicitly.
        raise ValueError("truncated string field")
    return buf[pos:end].decode('utf-8'), end
 | ||
| 
 | ||
| # --- 消息编码/解码函数 ---
 | ||
| 
 | ||
def encode_raw_485_command(bus_number, command_bytes):
    """Serialize a Raw485Command message.

    Layout: field 1 ``bus_number`` as a varint, then field 2 ``command_bytes``
    as a length-delimited byte string.

    Args:
        bus_number (int): bus number.
        command_bytes (bytes): raw RS-485 command frame.

    Returns:
        bytearray: the encoded message.
    """
    pieces = (
        encode_varint(1 << 3),              # tag: field 1, wire type 0 (varint)
        encode_varint(bus_number),
        encode_varint((2 << 3) | 2),        # tag: field 2, wire type 2 (length-delimited)
        encode_varint(len(command_bytes)),
        command_bytes,
    )
    out = bytearray()
    for piece in pieces:
        out.extend(piece)
    return out
 | ||
| 
 | ||
def decode_raw_485_command(buf):
    """Parse a Raw485Command message.

    Recognised fields are ``bus_number`` (field 1, varint) and
    ``command_bytes`` (field 2, length-delimited). Every other field is
    skipped according to its wire type. This includes a known field number
    that arrives with an unexpected wire type, so the read position always
    stays aligned with the next tag.

    Args:
        buf (bytes | bytearray): encoded message.

    Returns:
        dict: decoded fields. A field absent from ``buf`` is absent from the
        dict. ``command_bytes`` is a slice of ``buf`` (same type as ``buf``).

    Raises:
        ValueError: if a length-delimited or fixed-width field runs past the
            end of ``buf``, or an unsupported wire type (3, 4, 6, 7) is seen.
    """
    result = {}
    pos = 0
    end = len(buf)
    while pos < end:
        tag, pos = decode_varint(buf, pos)
        field_number = tag >> 3
        wire_type = tag & 0x07

        if wire_type == 0:  # varint
            value, pos = decode_varint(buf, pos)
            if field_number == 1:  # bus_number
                result['bus_number'] = value
        elif wire_type == 2:  # length-delimited
            length, pos = decode_varint(buf, pos)
            stop = pos + length
            if stop > end:
                raise ValueError("truncated length-delimited field")
            if field_number == 2:  # command_bytes
                result['command_bytes'] = buf[pos:stop]
            pos = stop
        elif wire_type == 1:  # fixed64: unknown here, skip
            pos += 8
        elif wire_type == 5:  # fixed32: unknown here, skip
            pos += 4
        else:
            # Groups (3/4) and 6/7 have no length we can skip by.
            raise ValueError("unsupported wire type %d" % wire_type)
    if pos > end:
        # Only a fixed-width skip can overshoot: the last field was cut short.
        raise ValueError("truncated fixed-width field")
    return result
 | ||
| 
 | ||
def encode_collect_task(command_msg):
    """Serialize a CollectTask message.

    Its only field is ``command`` (field 2), an embedded Raw485Command.

    Args:
        command_msg (dict): Raw485Command dict with ``'bus_number'`` and
            ``'command_bytes'`` keys.

    Returns:
        bytearray: the encoded message.
    """
    inner = encode_raw_485_command(command_msg['bus_number'], command_msg['command_bytes'])
    out = bytearray(encode_varint((2 << 3) | 2))  # tag: field 2, length-delimited
    out.extend(encode_varint(len(inner)))
    out.extend(inner)
    return out
 | ||
| 
 | ||
def decode_collect_task(buf):
    """Parse a CollectTask message.

    The only recognised field is ``command`` (field 2, length-delimited
    embedded Raw485Command). All other fields, including field 2 with an
    unexpected wire type, are skipped by wire type so parsing stays aligned.

    Args:
        buf (bytes | bytearray): encoded message.

    Returns:
        dict: ``{'command': <Raw485Command dict>}`` if field 2 was present,
        otherwise ``{}``.

    Raises:
        ValueError: if a length-delimited or fixed-width field runs past the
            end of ``buf``, or an unsupported wire type (3, 4, 6, 7) is seen.
    """
    result = {}
    pos = 0
    end = len(buf)
    while pos < end:
        tag, pos = decode_varint(buf, pos)
        field_number = tag >> 3
        wire_type = tag & 0x07

        if wire_type == 2:  # length-delimited
            length, pos = decode_varint(buf, pos)
            stop = pos + length
            if stop > end:
                raise ValueError("truncated length-delimited field")
            if field_number == 2:  # command
                result['command'] = decode_raw_485_command(buf[pos:stop])
            pos = stop
        elif wire_type == 0:  # varint: skip
            _, pos = decode_varint(buf, pos)
        elif wire_type == 1:  # fixed64: skip
            pos += 8
        elif wire_type == 5:  # fixed32: skip
            pos += 4
        else:
            raise ValueError("unsupported wire type %d" % wire_type)
    if pos > end:
        # Only a fixed-width skip can overshoot: the last field was cut short.
        raise ValueError("truncated fixed-width field")
    return result
 | ||
| 
 | ||
def encode_batch_collect_command(correlation_id, tasks):
    """Serialize a BatchCollectCommand message.

    Layout: field 1 ``correlation_id`` (string), then one field-2 entry per
    task (repeated embedded CollectTask), in list order.

    Args:
        correlation_id (str): correlation ID.
        tasks (list): CollectTask dicts, each with a ``'command'`` key holding
            a Raw485Command dict.

    Returns:
        bytearray: the encoded message.
    """
    out = bytearray(encode_varint((1 << 3) | 2))  # tag: field 1, length-delimited
    out.extend(encode_string(correlation_id))

    task_tag = encode_varint((2 << 3) | 2)  # tag: field 2, length-delimited (same for every task)
    for entry in tasks:
        body = encode_collect_task(entry['command'])
        out.extend(task_tag)
        out.extend(encode_varint(len(body)))
        out.extend(body)
    return out
 | ||
| 
 | ||
def decode_batch_collect_command(buf):
    """Parse a BatchCollectCommand message.

    Recognised fields are ``correlation_id`` (field 1, UTF-8 string) and
    ``tasks`` (field 2, repeated embedded CollectTask). Other fields, and
    known field numbers with an unexpected wire type, are skipped by wire
    type so parsing stays aligned.

    Args:
        buf (bytes | bytearray): encoded message.

    Returns:
        dict: always contains ``'tasks'`` (a list, possibly empty).
        ``'correlation_id'`` is present only if field 1 was on the wire.

    Raises:
        ValueError: if a length-delimited or fixed-width field runs past the
            end of ``buf``, or an unsupported wire type (3, 4, 6, 7) is seen.
    """
    result = {'tasks': []}
    pos = 0
    end = len(buf)
    while pos < end:
        tag, pos = decode_varint(buf, pos)
        field_number = tag >> 3
        wire_type = tag & 0x07

        if wire_type == 2:  # length-delimited
            length, pos = decode_varint(buf, pos)
            stop = pos + length
            if stop > end:
                raise ValueError("truncated length-delimited field")
            if field_number == 1:  # correlation_id
                result['correlation_id'] = buf[pos:stop].decode('utf-8')
            elif field_number == 2:  # tasks (repeated): one entry per occurrence
                result['tasks'].append(decode_collect_task(buf[pos:stop]))
            pos = stop
        elif wire_type == 0:  # varint: skip
            _, pos = decode_varint(buf, pos)
        elif wire_type == 1:  # fixed64: skip
            pos += 8
        elif wire_type == 5:  # fixed32: skip
            pos += 4
        else:
            raise ValueError("unsupported wire type %d" % wire_type)
    if pos > end:
        # Only a fixed-width skip can overshoot: the last field was cut short.
        raise ValueError("truncated fixed-width field")
    return result
 | ||
| 
 | ||
def encode_collect_result(correlation_id, values):
    """Serialize a CollectResult message.

    Layout: field 1 ``correlation_id`` (string), then one unpacked field-2
    entry per value. Each entry is a fixed32 tag followed by the value as a
    little-endian IEEE-754 single-precision float.

    Args:
        correlation_id (str): correlation ID.
        values (list): collected values (float). They are narrowed to 32-bit
            precision on the wire.

    Returns:
        bytearray: the encoded message.
    """
    out = bytearray(encode_varint((1 << 3) | 2))  # tag: field 1, length-delimited
    out.extend(encode_string(correlation_id))

    value_tag = encode_varint((2 << 3) | 5)  # tag: field 2, wire type 5 (fixed32)
    for reading in values:
        out.extend(value_tag)
        out.extend(struct.pack('<f', reading))  # little-endian float32
    return out
 | ||
| 
 | ||
def decode_collect_result(buf):
    """Parse a CollectResult message.

    Recognised fields:
      * ``correlation_id`` (field 1, UTF-8 string);
      * ``values`` (field 2, repeated float). Both the unpacked form (one
        fixed32 entry per value) and the packed form (one length-delimited
        run of little-endian float32s, the proto3 default) are accepted, as
        protobuf requires of parsers.

    Other fields are skipped by wire type so parsing stays aligned.

    Args:
        buf (bytes | bytearray): encoded message.

    Returns:
        dict: always contains ``'values'`` (list of float, possibly empty).
        ``'correlation_id'`` is present only if field 1 was on the wire.

    Raises:
        ValueError: if a field runs past the end of ``buf``, a packed
            ``values`` run is not a multiple of 4 bytes, or an unsupported
            wire type (3, 4, 6, 7) is seen.
    """
    result = {'values': []}
    values = result['values']
    pos = 0
    end = len(buf)
    while pos < end:
        tag, pos = decode_varint(buf, pos)
        field_number = tag >> 3
        wire_type = tag & 0x07

        if wire_type == 2:  # length-delimited
            length, pos = decode_varint(buf, pos)
            stop = pos + length
            if stop > end:
                raise ValueError("truncated length-delimited field")
            if field_number == 1:  # correlation_id
                result['correlation_id'] = buf[pos:stop].decode('utf-8')
            elif field_number == 2:  # packed values: back-to-back float32s
                if length % 4:
                    raise ValueError("packed float field length not a multiple of 4")
                for off in range(pos, stop, 4):
                    values.append(struct.unpack('<f', buf[off:off + 4])[0])
            pos = stop
        elif wire_type == 5:  # fixed32
            if pos + 4 > end:
                raise ValueError("truncated fixed32 field")
            if field_number == 2:  # unpacked value
                values.append(struct.unpack('<f', buf[pos:pos + 4])[0])
            pos += 4
        elif wire_type == 0:  # varint: skip
            _, pos = decode_varint(buf, pos)
        elif wire_type == 1:  # fixed64: skip
            pos += 8
        else:
            raise ValueError("unsupported wire type %d" % wire_type)
    if pos > end:
        # Only the fixed64 skip can overshoot: the last field was cut short.
        raise ValueError("truncated fixed-width field")
    return result
 | ||
| 
 | ||
def encode_instruction(payload_type, payload_data):
    """Serialize an Instruction message, whose payload is a oneof.

    Exactly one member is written, as a length-delimited embedded message
    under the field number that corresponds to ``payload_type``.

    Args:
        payload_type (str): which oneof member to write:
            ``'raw_485_command'`` (field 1), ``'batch_collect_command'``
            (field 2) or ``'collect_result'`` (field 3).
        payload_data (dict): message dict for that member, with the keys the
            matching ``encode_*`` function reads.

    Returns:
        bytearray: the encoded Instruction.

    Raises:
        ValueError: if ``payload_type`` is not one of the names above.
    """
    # (oneof member name, field number, payload encoder), compared in order.
    members = (
        ('raw_485_command', 1,
         lambda d: encode_raw_485_command(d['bus_number'], d['command_bytes'])),
        ('batch_collect_command', 2,
         lambda d: encode_batch_collect_command(d['correlation_id'], d['tasks'])),
        ('collect_result', 3,
         lambda d: encode_collect_result(d['correlation_id'], d['values'])),
    )
    for name, field_number, encode_payload in members:
        if payload_type == name:
            break
    else:
        raise ValueError("未知的指令负载类型")

    payload = encode_payload(payload_data)
    out = bytearray(encode_varint((field_number << 3) | 2))  # tag, wire type 2
    out.extend(encode_varint(len(payload)))
    out.extend(payload)
    return out
 | ||
| 
 | ||
def decode_instruction(buf):
    """Parse an Instruction message, whose payload is a oneof.

    Oneof members (all length-delimited embedded messages):
      * field 1 -> ``'raw_485_command'``
      * field 2 -> ``'batch_collect_command'``
      * field 3 -> ``'collect_result'``

    Following protobuf oneof semantics, if several members appear on the
    wire, only the last one is kept. Unknown fields are skipped by wire type.

    Args:
        buf (bytes | bytearray): encoded message.

    Returns:
        dict: at most one key (the member that was set), or ``{}``.

    Raises:
        ValueError: if a length-delimited or fixed-width field runs past the
            end of ``buf``, or an unsupported wire type (3, 4, 6, 7) is seen.
    """
    result = {}
    pos = 0
    end = len(buf)
    while pos < end:
        tag, pos = decode_varint(buf, pos)
        field_number = tag >> 3
        wire_type = tag & 0x07

        if wire_type == 2:  # length-delimited; every oneof member uses this
            length, pos = decode_varint(buf, pos)
            stop = pos + length
            if stop > end:
                raise ValueError("truncated length-delimited field")
            chunk = buf[pos:stop]
            pos = stop
            # Rebinding `result` drops any member seen earlier (last one wins).
            if field_number == 1:
                result = {'raw_485_command': decode_raw_485_command(chunk)}
            elif field_number == 2:
                result = {'batch_collect_command': decode_batch_collect_command(chunk)}
            elif field_number == 3:
                result = {'collect_result': decode_collect_result(chunk)}
        elif wire_type == 0:  # varint: skip
            _, pos = decode_varint(buf, pos)
        elif wire_type == 1:  # fixed64: skip
            pos += 8
        elif wire_type == 5:  # fixed32: skip
            pos += 4
        else:
            raise ValueError("unsupported wire type %d" % wire_type)
    if pos > end:
        # Only a fixed-width skip can overshoot: the last field was cut short.
        raise ValueError("truncated fixed-width field")
    return result
 | ||
| 
 | ||
| # --- 单元测试与使用范例 ---
 | ||
if __name__ == "__main__":
    # --- Self-tests / usage examples ---
    # Each message type is encoded, decoded again, and compared with the input.

    def _assert_floats_close(actual, expected, tol=1e-5):
        """Assert two float lists agree element-wise within ``tol``.

        CollectResult values travel as 32-bit floats, so the decoded values
        cannot be expected to equal the original Python floats exactly.
        """
        assert len(actual) == len(expected)
        for got, want in zip(actual, expected):
            assert abs(got - want) < tol

    print("--- 测试 Raw485Command ---")
    raw_cmd_data = {'bus_number': 1, 'command_bytes': b'\x01\x03\x00\x00\x00\x02\xc4\x0b'}
    encoded_raw_cmd = encode_raw_485_command(raw_cmd_data['bus_number'], raw_cmd_data['command_bytes'])
    print(f"编码后 Raw485Command: {encoded_raw_cmd.hex()}")
    decoded_raw_cmd = decode_raw_485_command(encoded_raw_cmd)
    print(f"解码后 Raw485Command: {decoded_raw_cmd}")
    assert decoded_raw_cmd == raw_cmd_data

    print("\n--- 测试 CollectTask ---")
    collect_task_data = {'command': raw_cmd_data}
    encoded_collect_task = encode_collect_task(collect_task_data['command'])
    print(f"编码后 CollectTask: {encoded_collect_task.hex()}")
    decoded_collect_task = decode_collect_task(encoded_collect_task)
    print(f"解码后 CollectTask: {decoded_collect_task}")
    assert decoded_collect_task == collect_task_data

    print("\n--- 测试 BatchCollectCommand ---")
    batch_collect_data = {
        'correlation_id': 'abc-123',
        'tasks': [
            {'command': {'bus_number': 1, 'command_bytes': b'\x01\x03\x00\x00\x00\x02\xc4\x0b'}},
            {'command': {'bus_number': 2, 'command_bytes': b'\x02\x03\x00\x01\x00\x01\xd5\xfa'}}
        ]
    }
    encoded_batch_collect = encode_batch_collect_command(batch_collect_data['correlation_id'], batch_collect_data['tasks'])
    print(f"编码后 BatchCollectCommand: {encoded_batch_collect.hex()}")
    decoded_batch_collect = decode_batch_collect_command(encoded_batch_collect)
    print(f"解码后 BatchCollectCommand: {decoded_batch_collect}")
    assert decoded_batch_collect == batch_collect_data

    print("\n--- 测试 CollectResult ---")
    collect_result_data = {
        'correlation_id': 'res-456',
        'values': [12.34, 56.78, 90.12]
    }
    encoded_collect_result = encode_collect_result(collect_result_data['correlation_id'], collect_result_data['values'])
    print(f"编码后 CollectResult: {encoded_collect_result.hex()}")
    decoded_collect_result = decode_collect_result(encoded_collect_result)
    print(f"解码后 CollectResult: {decoded_collect_result}")
    assert decoded_collect_result['correlation_id'] == collect_result_data['correlation_id']
    # 32-bit float precision loss: compare with a tolerance, not exactly.
    _assert_floats_close(decoded_collect_result['values'], collect_result_data['values'])

    print("\n--- 测试 Instruction (内含Raw485Command) ---")
    instruction_raw_485 = encode_instruction('raw_485_command', raw_cmd_data)
    print(f"编码后 Instruction (Raw485Command): {instruction_raw_485.hex()}")
    decoded_instruction_raw_485 = decode_instruction(instruction_raw_485)
    print(f"解码后 Instruction (Raw485Command): {decoded_instruction_raw_485}")
    assert decoded_instruction_raw_485['raw_485_command'] == raw_cmd_data

    print("\n--- 测试 Instruction (内含BatchCollectCommand) ---")
    instruction_batch_collect = encode_instruction('batch_collect_command', batch_collect_data)
    print(f"编码后 Instruction (BatchCollectCommand): {instruction_batch_collect.hex()}")
    decoded_instruction_batch_collect = decode_instruction(instruction_batch_collect)
    print(f"解码后 Instruction (BatchCollectCommand): {decoded_instruction_batch_collect}")
    # Compare the whole nested message, including each task's command, rather
    # than only the correlation id and the number of tasks.
    assert decoded_instruction_batch_collect['batch_collect_command'] == batch_collect_data

    print("\n--- 测试 Instruction (内含CollectResult) ---")
    instruction_collect_result = encode_instruction('collect_result', collect_result_data)
    print(f"编码后 Instruction (CollectResult): {instruction_collect_result.hex()}")
    decoded_instruction_collect_result = decode_instruction(instruction_collect_result)
    print(f"解码后 Instruction (CollectResult): {decoded_instruction_collect_result}")
    inner_result = decoded_instruction_collect_result['collect_result']
    assert inner_result['correlation_id'] == collect_result_data['correlation_id']
    _assert_floats_close(inner_result['values'], collect_result_data['values'])

    print("\n所有测试均已通过!")
 |