更新proto
This commit is contained in:
542
client_pb.py
542
client_pb.py
@@ -8,9 +8,7 @@
|
||||
|
||||
import struct
|
||||
|
||||
# MethodType枚举
|
||||
METHOD_TYPE_SWITCH = 0
|
||||
METHOD_TYPE_COLLECT = 1
|
||||
# --- Helper Functions for Protobuf Basic Types ---
|
||||
|
||||
def encode_varint(value):
|
||||
"""编码varint值"""
|
||||
@@ -47,260 +45,336 @@ def decode_string(buf, pos=0):
|
||||
pos += length
|
||||
return value, pos
|
||||
|
||||
def encode_instruction(method, data):
|
||||
# --- Message Encoding/Decoding Functions ---
|
||||
|
||||
def encode_raw_485_command(bus_number, command_bytes):
|
||||
"""
|
||||
编码Instruction消息
|
||||
|
||||
编码Raw485Command消息
|
||||
Args:
|
||||
method: 方法类型 (int)
|
||||
data: 数据 (bytes)
|
||||
|
||||
bus_number: 总线号 (int)
|
||||
command_bytes: 原始485指令的字节数组 (bytes)
|
||||
Returns:
|
||||
bytearray: 编码后的数据
|
||||
"""
|
||||
result = bytearray()
|
||||
|
||||
# 编码method字段 (field_number=1, wire_type=0)
|
||||
result.extend(encode_varint((1 << 3) | 0)) # tag
|
||||
result.extend(encode_varint(method)) # value
|
||||
|
||||
# 编码data字段 (field_number=2, wire_type=2)
|
||||
result.extend(encode_varint((2 << 3) | 2)) # tag
|
||||
result.extend(encode_varint(len(data))) # length
|
||||
result.extend(data) # value
|
||||
|
||||
# bus_number (field 1, wire type 0)
|
||||
result.extend(encode_varint((1 << 3) | 0))
|
||||
result.extend(encode_varint(bus_number))
|
||||
# command_bytes (field 2, wire type 2)
|
||||
result.extend(encode_varint((2 << 3) | 2))
|
||||
result.extend(encode_varint(len(command_bytes)))
|
||||
result.extend(command_bytes)
|
||||
return result
|
||||
|
||||
def decode_raw_485_command(buf):
|
||||
"""
|
||||
解码Raw485Command消息
|
||||
Args:
|
||||
buf: 编码后的数据 (bytes)
|
||||
Returns:
|
||||
dict: 解码后的消息
|
||||
"""
|
||||
result = {}
|
||||
pos = 0
|
||||
while pos < len(buf):
|
||||
tag, pos = decode_varint(buf, pos)
|
||||
field_number = tag >> 3
|
||||
wire_type = tag & 0x07
|
||||
|
||||
if field_number == 1: # bus_number
|
||||
if wire_type == 0:
|
||||
value, pos = decode_varint(buf, pos)
|
||||
result['bus_number'] = value
|
||||
elif field_number == 2: # command_bytes
|
||||
if wire_type == 2:
|
||||
length, pos = decode_varint(buf, pos)
|
||||
value = buf[pos:pos+length]
|
||||
pos += length
|
||||
result['command_bytes'] = value
|
||||
else:
|
||||
# 跳过未知字段
|
||||
if wire_type == 0: _, pos = decode_varint(buf, pos)
|
||||
elif wire_type == 2: length, pos = decode_varint(buf, pos); pos += length
|
||||
else: pos += 1
|
||||
return result
|
||||
|
||||
def encode_collect_task(command_msg):
|
||||
"""
|
||||
编码CollectTask消息
|
||||
Args:
|
||||
command_msg: Raw485Command消息字典
|
||||
Returns:
|
||||
bytearray: 编码后的数据
|
||||
"""
|
||||
result = bytearray()
|
||||
# command (field 2, wire type 2)
|
||||
encoded_command = encode_raw_485_command(command_msg['bus_number'], command_msg['command_bytes'])
|
||||
result.extend(encode_varint((2 << 3) | 2))
|
||||
result.extend(encode_varint(len(encoded_command)))
|
||||
result.extend(encoded_command)
|
||||
return result
|
||||
|
||||
def decode_collect_task(buf):
|
||||
"""
|
||||
解码CollectTask消息
|
||||
Args:
|
||||
buf: 编码后的数据 (bytes)
|
||||
Returns:
|
||||
dict: 解码后的消息
|
||||
"""
|
||||
result = {}
|
||||
pos = 0
|
||||
while pos < len(buf):
|
||||
tag, pos = decode_varint(buf, pos)
|
||||
field_number = tag >> 3
|
||||
wire_type = tag & 0x07
|
||||
|
||||
if field_number == 2: # command
|
||||
if wire_type == 2:
|
||||
length, pos = decode_varint(buf, pos)
|
||||
value_buf = buf[pos:pos+length]
|
||||
pos += length
|
||||
result['command'] = decode_raw_485_command(value_buf)
|
||||
else:
|
||||
if wire_type == 0: _, pos = decode_varint(buf, pos)
|
||||
elif wire_type == 2: length, pos = decode_varint(buf, pos); pos += length
|
||||
else: pos += 1
|
||||
return result
|
||||
|
||||
def encode_batch_collect_command(correlation_id, tasks):
|
||||
"""
|
||||
编码BatchCollectCommand消息
|
||||
Args:
|
||||
correlation_id: 关联ID (str)
|
||||
tasks: CollectTask消息字典列表
|
||||
Returns:
|
||||
bytearray: 编码后的数据
|
||||
"""
|
||||
result = bytearray()
|
||||
# correlation_id (field 1, wire type 2)
|
||||
result.extend(encode_varint((1 << 3) | 2))
|
||||
result.extend(encode_string(correlation_id))
|
||||
# tasks (field 2, wire type 2) - repeated
|
||||
for task in tasks:
|
||||
encoded_task = encode_collect_task(task['command'])
|
||||
result.extend(encode_varint((2 << 3) | 2))
|
||||
result.extend(encode_varint(len(encoded_task)))
|
||||
result.extend(encoded_task)
|
||||
return result
|
||||
|
||||
def decode_batch_collect_command(buf):
|
||||
"""
|
||||
解码BatchCollectCommand消息
|
||||
Args:
|
||||
buf: 编码后的数据 (bytes)
|
||||
Returns:
|
||||
dict: 解码后的消息
|
||||
"""
|
||||
result = {'tasks': []}
|
||||
pos = 0
|
||||
while pos < len(buf):
|
||||
tag, pos = decode_varint(buf, pos)
|
||||
field_number = tag >> 3
|
||||
wire_type = tag & 0x07
|
||||
|
||||
if field_number == 1: # correlation_id
|
||||
if wire_type == 2:
|
||||
value, pos = decode_string(buf, pos)
|
||||
result['correlation_id'] = value
|
||||
elif field_number == 2: # tasks (repeated)
|
||||
if wire_type == 2:
|
||||
length, pos = decode_varint(buf, pos)
|
||||
value_buf = buf[pos:pos+length]
|
||||
pos += length
|
||||
result['tasks'].append(decode_collect_task(value_buf))
|
||||
else:
|
||||
if wire_type == 0: _, pos = decode_varint(buf, pos)
|
||||
elif wire_type == 2: length, pos = decode_varint(buf, pos); pos += length
|
||||
else: pos += 1
|
||||
return result
|
||||
|
||||
def encode_collect_result(correlation_id, values):
|
||||
"""
|
||||
编码CollectResult消息
|
||||
Args:
|
||||
correlation_id: 关联ID (str)
|
||||
values: 采集值列表 (list of float)
|
||||
Returns:
|
||||
bytearray: 编码后的数据
|
||||
"""
|
||||
result = bytearray()
|
||||
# correlation_id (field 1, wire type 2)
|
||||
result.extend(encode_varint((1 << 3) | 2))
|
||||
result.extend(encode_string(correlation_id))
|
||||
# values (field 2, wire type 5) - repeated fixed32
|
||||
for value in values:
|
||||
result.extend(encode_varint((2 << 3) | 5)) # Tag for fixed32
|
||||
result.extend(struct.pack('<f', value)) # Little-endian float
|
||||
return result
|
||||
|
||||
def decode_collect_result(buf):
|
||||
"""
|
||||
解码CollectResult消息
|
||||
Args:
|
||||
buf: 编码后的数据 (bytes)
|
||||
Returns:
|
||||
dict: 解码后的消息
|
||||
"""
|
||||
result = {'values': []}
|
||||
pos = 0
|
||||
while pos < len(buf):
|
||||
tag, pos = decode_varint(buf, pos)
|
||||
field_number = tag >> 3
|
||||
wire_type = tag & 0x07
|
||||
|
||||
if field_number == 1: # correlation_id
|
||||
if wire_type == 2:
|
||||
value, pos = decode_string(buf, pos)
|
||||
result['correlation_id'] = value
|
||||
elif field_number == 2: # values (repeated)
|
||||
if wire_type == 5: # fixed32
|
||||
value = struct.unpack('<f', buf[pos:pos+4])[0]
|
||||
pos += 4
|
||||
result['values'].append(value)
|
||||
else:
|
||||
if wire_type == 0: _, pos = decode_varint(buf, pos)
|
||||
elif wire_type == 5: pos += 4 # fixed32
|
||||
elif wire_type == 2: length, pos = decode_varint(buf, pos); pos += length
|
||||
else: pos += 1
|
||||
return result
|
||||
|
||||
def encode_instruction(payload_type, payload_data):
|
||||
"""
|
||||
编码Instruction消息 (包含oneof字段)
|
||||
Args:
|
||||
payload_type: oneof字段的类型 ('raw_485_command', 'batch_collect_command', 'collect_result')
|
||||
payload_data: 对应类型的消息字典
|
||||
Returns:
|
||||
bytearray: 编码后的数据
|
||||
"""
|
||||
result = bytearray()
|
||||
encoded_payload = bytearray()
|
||||
|
||||
if payload_type == 'raw_485_command':
|
||||
encoded_payload = encode_raw_485_command(payload_data['bus_number'], payload_data['command_bytes'])
|
||||
result.extend(encode_varint((1 << 3) | 2)) # field 1, wire type 2
|
||||
elif payload_type == 'batch_collect_command':
|
||||
encoded_payload = encode_batch_collect_command(payload_data['correlation_id'], payload_data['tasks'])
|
||||
result.extend(encode_varint((2 << 3) | 2)) # field 2, wire type 2
|
||||
elif payload_type == 'collect_result':
|
||||
encoded_payload = encode_collect_result(payload_data['correlation_id'], payload_data['values'])
|
||||
result.extend(encode_varint((3 << 3) | 2)) # field 3, wire type 2
|
||||
else:
|
||||
raise ValueError("Unknown instruction payload type")
|
||||
|
||||
result.extend(encode_varint(len(encoded_payload)))
|
||||
result.extend(encoded_payload)
|
||||
return result
|
||||
|
||||
def decode_instruction(buf):
|
||||
"""
|
||||
解码Instruction消息
|
||||
|
||||
Args:
|
||||
buf: 编码后的数据 (bytes)
|
||||
|
||||
Returns:
|
||||
dict: 解码后的消息
|
||||
"""
|
||||
result = {}
|
||||
pos = 0
|
||||
|
||||
while pos < len(buf):
|
||||
# 读取标签
|
||||
tag, pos = decode_varint(buf, pos)
|
||||
field_number = tag >> 3
|
||||
wire_type = tag & 0x07
|
||||
|
||||
if field_number == 1: # method字段
|
||||
if wire_type == 0: # varint类型
|
||||
value, pos = decode_varint(buf, pos)
|
||||
result['method'] = value
|
||||
elif field_number == 2: # data字段
|
||||
if wire_type == 2: # 长度分隔类型
|
||||
length, pos = decode_varint(buf, pos)
|
||||
value = buf[pos:pos+length]
|
||||
pos += length
|
||||
result['data'] = value
|
||||
|
||||
if wire_type == 2: # Length-delimited type for all oneof fields
|
||||
length, pos = decode_varint(buf, pos)
|
||||
value_buf = buf[pos:pos+length]
|
||||
pos += length
|
||||
|
||||
if field_number == 1: # raw_485_command
|
||||
result['raw_485_command'] = decode_raw_485_command(value_buf)
|
||||
elif field_number == 2: # batch_collect_command
|
||||
result['batch_collect_command'] = decode_batch_collect_command(value_buf)
|
||||
elif field_number == 3: # collect_result
|
||||
result['collect_result'] = decode_collect_result(value_buf)
|
||||
# else: unknown field, already skipped by default behavior
|
||||
else:
|
||||
# 跳过未知字段
|
||||
if wire_type == 0: # varint
|
||||
_, pos = decode_varint(buf, pos)
|
||||
elif wire_type == 2: # 长度分隔
|
||||
length, pos = decode_varint(buf, pos)
|
||||
pos += length
|
||||
else:
|
||||
pos += 1
|
||||
|
||||
# 跳过未知字段 (或非长度分隔类型,尽管oneof字段通常是长度分隔的)
|
||||
if wire_type == 0: _, pos = decode_varint(buf, pos)
|
||||
elif wire_type == 5: pos += 4 # fixed32
|
||||
elif wire_type == 2: length, pos = decode_varint(buf, pos); pos += length
|
||||
else: pos += 1
|
||||
return result
|
||||
|
||||
def encode_switch(device_action, bus_number, bus_address, relay_channel):
|
||||
"""
|
||||
编码Switch消息
|
||||
|
||||
Args:
|
||||
device_action: 设备动作指令 (str)
|
||||
bus_number: 总线号 (int)
|
||||
bus_address: 总线地址 (int)
|
||||
relay_channel: 继电器通道号 (int)
|
||||
|
||||
Returns:
|
||||
bytearray: 编码后的数据
|
||||
"""
|
||||
result = bytearray()
|
||||
|
||||
# 编码device_action字段 (field_number=1, wire_type=2)
|
||||
result.extend(encode_varint((1 << 3) | 2)) # tag
|
||||
action_bytes = encode_string(device_action) # value (length + string)
|
||||
result.extend(action_bytes)
|
||||
|
||||
# 编码bus_number字段 (field_number=2, wire_type=0)
|
||||
result.extend(encode_varint((2 << 3) | 0)) # tag
|
||||
result.extend(encode_varint(bus_number)) # value
|
||||
|
||||
# 编码bus_address字段 (field_number=3, wire_type=0)
|
||||
result.extend(encode_varint((3 << 3) | 0)) # tag
|
||||
result.extend(encode_varint(bus_address)) # value
|
||||
|
||||
# 编码relay_channel字段 (field_number=4, wire_type=0)
|
||||
result.extend(encode_varint((4 << 3) | 0)) # tag
|
||||
result.extend(encode_varint(relay_channel)) # value
|
||||
|
||||
return result
|
||||
|
||||
def decode_switch(buf):
|
||||
"""
|
||||
解码Switch消息
|
||||
|
||||
Args:
|
||||
buf: 编码后的数据 (bytes)
|
||||
|
||||
Returns:
|
||||
dict: 解码后的消息
|
||||
"""
|
||||
result = {}
|
||||
pos = 0
|
||||
|
||||
while pos < len(buf):
|
||||
# 读取标签
|
||||
tag, pos = decode_varint(buf, pos)
|
||||
field_number = tag >> 3
|
||||
wire_type = tag & 0x07
|
||||
|
||||
if field_number == 1: # device_action字段
|
||||
if wire_type == 2: # 字符串类型
|
||||
value, pos = decode_string(buf, pos)
|
||||
result['device_action'] = value
|
||||
elif field_number == 2: # bus_number字段
|
||||
if wire_type == 0: # varint类型
|
||||
value, pos = decode_varint(buf, pos)
|
||||
result['bus_number'] = value
|
||||
elif field_number == 3: # bus_address字段
|
||||
if wire_type == 0: # varint类型
|
||||
value, pos = decode_varint(buf, pos)
|
||||
result['bus_address'] = value
|
||||
elif field_number == 4: # relay_channel字段
|
||||
if wire_type == 0: # varint类型
|
||||
value, pos = decode_varint(buf, pos)
|
||||
result['relay_channel'] = value
|
||||
else:
|
||||
# 跳过未知字段
|
||||
if wire_type == 0: # varint
|
||||
_, pos = decode_varint(buf, pos)
|
||||
elif wire_type == 2: # 长度分隔
|
||||
length, pos = decode_varint(buf, pos)
|
||||
pos += length
|
||||
else:
|
||||
pos += 1
|
||||
|
||||
return result
|
||||
|
||||
def encode_collect(bus_number, bus_address, value):
|
||||
"""
|
||||
编码Collect消息
|
||||
|
||||
Args:
|
||||
bus_number: 总线号 (int)
|
||||
bus_address: 总线地址 (int)
|
||||
value: 采集值 (float)
|
||||
|
||||
Returns:
|
||||
bytearray: 编码后的数据
|
||||
"""
|
||||
result = bytearray()
|
||||
|
||||
# 编码bus_number字段 (field_number=1, wire_type=0)
|
||||
result.extend(encode_varint((1 << 3) | 0)) # tag
|
||||
result.extend(encode_varint(bus_number)) # value
|
||||
|
||||
# 编码bus_address字段 (field_number=2, wire_type=0)
|
||||
result.extend(encode_varint((2 << 3) | 0)) # tag
|
||||
result.extend(encode_varint(bus_address)) # value
|
||||
|
||||
# 编码value字段 (field_number=3, wire_type=5)
|
||||
result.extend(encode_varint((3 << 3) | 5)) # tag
|
||||
# 将float转换为little-endian的4字节
|
||||
result.extend(struct.pack('<f', value)) # value
|
||||
|
||||
return result
|
||||
|
||||
def decode_collect(buf):
|
||||
"""
|
||||
解码Collect消息
|
||||
|
||||
Args:
|
||||
buf: 编码后的数据 (bytes)
|
||||
|
||||
Returns:
|
||||
dict: 解码后的消息
|
||||
"""
|
||||
result = {}
|
||||
pos = 0
|
||||
|
||||
while pos < len(buf):
|
||||
# 读取标签
|
||||
tag, pos = decode_varint(buf, pos)
|
||||
field_number = tag >> 3
|
||||
wire_type = tag & 0x07
|
||||
|
||||
if field_number == 1: # bus_number字段
|
||||
if wire_type == 0: # varint类型
|
||||
value, pos = decode_varint(buf, pos)
|
||||
result['bus_number'] = value
|
||||
elif field_number == 2: # bus_address字段
|
||||
if wire_type == 0: # varint类型
|
||||
value, pos = decode_varint(buf, pos)
|
||||
result['bus_address'] = value
|
||||
elif field_number == 3: # value字段
|
||||
if wire_type == 5: # 32位浮点类型
|
||||
# 从little-endian的4字节解析float
|
||||
value = struct.unpack('<f', buf[pos:pos+4])[0]
|
||||
pos += 4
|
||||
result['value'] = value
|
||||
else:
|
||||
# 跳过未知字段
|
||||
if wire_type == 0: # varint
|
||||
_, pos = decode_varint(buf, pos)
|
||||
elif wire_type == 5: # 32位固定长度
|
||||
pos += 4
|
||||
elif wire_type == 2: # 长度分隔
|
||||
length, pos = decode_varint(buf, pos)
|
||||
pos += length
|
||||
else:
|
||||
pos += 1
|
||||
|
||||
return result
|
||||
|
||||
# 使用示例
|
||||
# --- Usage Example ---
|
||||
if __name__ == "__main__":
|
||||
# 创建一个Switch消息
|
||||
switch_data = encode_switch("ON", 1, 10, 2)
|
||||
print(f"编码后的Switch消息: {switch_data.hex()}")
|
||||
|
||||
# 创建一个Instruction消息,包含Switch数据
|
||||
instruction_data = encode_instruction(METHOD_TYPE_SWITCH, switch_data)
|
||||
print(f"编码后的Instruction消息: {instruction_data.hex()}")
|
||||
|
||||
# 解码Instruction消息
|
||||
decoded_instruction = decode_instruction(instruction_data)
|
||||
print(f"解码后的Instruction消息: {decoded_instruction}")
|
||||
|
||||
# 解码Switch消息
|
||||
if 'data' in decoded_instruction:
|
||||
decoded_switch = decode_switch(decoded_instruction['data'])
|
||||
print(f"解码后的Switch消息: {decoded_switch}")
|
||||
|
||||
# 创建一个Collect消息
|
||||
collect_data = encode_collect(1, 20, 25.6)
|
||||
print(f"编码后的Collect消息: {collect_data.hex()}")
|
||||
|
||||
# 创建一个Instruction消息,包含Collect数据
|
||||
instruction_data2 = encode_instruction(METHOD_TYPE_COLLECT, collect_data)
|
||||
print(f"编码后的Instruction消息(Collect): {instruction_data2.hex()}")
|
||||
|
||||
# 解码Instruction消息
|
||||
decoded_instruction2 = decode_instruction(instruction_data2)
|
||||
print(f"解码后的Instruction消息: {decoded_instruction2}")
|
||||
|
||||
# 解码Collect消息
|
||||
if 'data' in decoded_instruction2:
|
||||
decoded_collect = decode_collect(decoded_instruction2['data'])
|
||||
print(f"解码后的Collect消息: {decoded_collect}")
|
||||
print("--- Testing Raw485Command ---")
|
||||
raw_cmd_data = {'bus_number': 1, 'command_bytes': b'\x01\x03\x00\x00\x00\x02\xc4\x0b'}
|
||||
encoded_raw_cmd = encode_raw_485_command(raw_cmd_data['bus_number'], raw_cmd_data['command_bytes'])
|
||||
print(f"Encoded Raw485Command: {encoded_raw_cmd.hex()}")
|
||||
decoded_raw_cmd = decode_raw_485_command(encoded_raw_cmd)
|
||||
print(f"Decoded Raw485Command: {decoded_raw_cmd}")
|
||||
assert decoded_raw_cmd == raw_cmd_data
|
||||
|
||||
print("\n--- Testing CollectTask ---")
|
||||
collect_task_data = {'command': raw_cmd_data}
|
||||
encoded_collect_task = encode_collect_task(collect_task_data['command'])
|
||||
print(f"Encoded CollectTask: {encoded_collect_task.hex()}")
|
||||
decoded_collect_task = decode_collect_task(encoded_collect_task)
|
||||
print(f"Decoded CollectTask: {decoded_collect_task}")
|
||||
assert decoded_collect_task == collect_task_data
|
||||
|
||||
print("\n--- Testing BatchCollectCommand ---")
|
||||
batch_collect_data = {
|
||||
'correlation_id': 'abc-123',
|
||||
'tasks': [
|
||||
{'command': {'bus_number': 1, 'command_bytes': b'\x01\x03\x00\x00\x00\x02\xc4\x0b'}},
|
||||
{'command': {'bus_number': 2, 'command_bytes': b'\x02\x03\x00\x01\x00\x01\xd5\xfa'}}
|
||||
]
|
||||
}
|
||||
encoded_batch_collect = encode_batch_collect_command(batch_collect_data['correlation_id'], batch_collect_data['tasks'])
|
||||
print(f"Encoded BatchCollectCommand: {encoded_batch_collect.hex()}")
|
||||
decoded_batch_collect = decode_batch_collect_command(encoded_batch_collect)
|
||||
print(f"Decoded BatchCollectCommand: {decoded_batch_collect}")
|
||||
assert decoded_batch_collect == batch_collect_data
|
||||
|
||||
print("\n--- Testing CollectResult ---")
|
||||
collect_result_data = {
|
||||
'correlation_id': 'res-456',
|
||||
'values': [12.34, 56.78, 90.12]
|
||||
}
|
||||
encoded_collect_result = encode_collect_result(collect_result_data['correlation_id'], collect_result_data['values'])
|
||||
print(f"Encoded CollectResult: {encoded_collect_result.hex()}")
|
||||
decoded_collect_result = decode_collect_result(encoded_collect_result)
|
||||
print(f"Decoded CollectResult: {decoded_collect_result}")
|
||||
# Due to float precision, direct assert might fail. Compare elements.
|
||||
assert decoded_collect_result['correlation_id'] == collect_result_data['correlation_id']
|
||||
for i in range(len(collect_result_data['values'])):
|
||||
assert abs(decoded_collect_result['values'][i] - collect_result_data['values'][i]) < 1e-6
|
||||
|
||||
print("\n--- Testing Instruction with Raw485Command ---")
|
||||
instruction_raw_485 = encode_instruction('raw_485_command', raw_cmd_data)
|
||||
print(f"Encoded Instruction (Raw485Command): {instruction_raw_485.hex()}")
|
||||
decoded_instruction_raw_485 = decode_instruction(instruction_raw_485)
|
||||
print(f"Decoded Instruction (Raw485Command): {decoded_instruction_raw_485}")
|
||||
assert decoded_instruction_raw_485['raw_485_command'] == raw_cmd_data
|
||||
|
||||
print("\n--- Testing Instruction with BatchCollectCommand ---")
|
||||
instruction_batch_collect = encode_instruction('batch_collect_command', batch_collect_data)
|
||||
print(f"Encoded Instruction (BatchCollectCommand): {instruction_batch_collect.hex()}")
|
||||
decoded_instruction_batch_collect = decode_instruction(instruction_batch_collect)
|
||||
print(f"Decoded Instruction (BatchCollectCommand): {decoded_instruction_batch_collect}")
|
||||
assert decoded_instruction_batch_collect['batch_collect_command']['correlation_id'] == batch_collect_data['correlation_id']
|
||||
assert len(decoded_instruction_batch_collect['batch_collect_command']['tasks']) == len(batch_collect_data['tasks'])
|
||||
# More detailed assertion for tasks if needed
|
||||
|
||||
print("\n--- Testing Instruction with CollectResult ---")
|
||||
instruction_collect_result = encode_instruction('collect_result', collect_result_data)
|
||||
print(f"Encoded Instruction (CollectResult): {instruction_collect_result.hex()}")
|
||||
decoded_instruction_collect_result = decode_instruction(instruction_collect_result)
|
||||
print(f"Decoded Instruction (CollectResult): {decoded_instruction_collect_result}")
|
||||
assert decoded_instruction_collect_result['collect_result']['correlation_id'] == collect_result_data['correlation_id']
|
||||
for i in range(len(collect_result_data['values'])):
|
||||
assert abs(decoded_instruction_collect_result['collect_result']['values'][i] - collect_result_data['values'][i]) < 1e-6
|
||||
|
||||
print("\nAll tests passed!")
|
||||
|
||||
Reference in New Issue
Block a user