# Source code of module secretflow.kuscia.task_config
# Copyright 2023 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Dict, List, Union
import spu
import yaml
from google.protobuf import json_format
from secretflow.protos.component.node_def_pb2 import NodeDef
from secretflow.protos.kuscia.kuscia_task_pb2 import AllocatedPorts, ClusterDefine
@dataclass
class TaskConfig:
    """Per-party runtime configuration of a Kuscia task.

    Instances are normally created empty and then filled in by
    :meth:`parse_from_file`; every field stays ``None`` until that call.
    """

    task_id: str = None
    # Index of this party in ``ClusterDefine.parties`` (its ``self_party_idx``).
    party_id: int = None
    party_name: str = None
    # Host and port taken from this party's 'global' service endpoint.
    ray_node_ip_address: str = None
    ray_gcs_port: int = None
    ray_node_manager_port: int = None
    ray_object_manager_port: int = None
    ray_client_server_port: int = None
    # Every allocated port whose name starts with 'ray-worker'.
    ray_worker_ports: List[int] = None
    spu_port: int = None
    fed_port: int = None
    # Shape: {'parties': {party_name: {'address': 'host:port'}},
    #         'self_party': party_name}
    cluster_config: Dict[str, Union[str, Dict[str, str]]] = None
    # Shape: {'pyu': [party_name, ...],
    #         'spu': {'nodes': [{'party': ..., 'id': ..., 'address': ...}, ...],
    #                 'runtime_config': {'protocol': ..., 'field': ...}}}
    # 'runtime_config' is only present for 2- and 3-party clusters.
    spu_cluster_config: Dict[
        str,
        Union[
            List[str],
            Dict[str, Union[List[Dict[str, str]], Dict[str, Union[int, bool]]]],
        ],
    ] = None
    comp_node: NodeDef = None

    def parse_from_file(self, task_config_path: str):
        """Populate this config from a Kuscia task config YAML file.

        The YAML document must contain the keys ``task_id``,
        ``task_input_config`` (JSON text of a ``NodeDef``),
        ``task_input_cluster_def`` (JSON text of a ``ClusterDefine``) and
        ``allocated_ports`` (JSON text of an ``AllocatedPorts``).

        Args:
            task_config_path: path of the task config YAML file.

        Returns:
            ``self``, so calls can be chained, e.g.
            ``TaskConfig().parse_from_file(path)``.

        Raises:
            KeyError: if one of the required YAML keys is missing.
            google.protobuf.json_format.ParseError: if an embedded JSON text
                cannot be parsed into its protobuf message.
        """
        # Decode explicitly as UTF-8 rather than with the platform's locale
        # encoding. Close the file as soon as the YAML is loaded, instead of
        # keeping it open through all of the protobuf/config processing below.
        with open(task_config_path, encoding='utf-8') as f:
            configs = yaml.safe_load(f)

        cluster_define = ClusterDefine()
        allocated_ports = AllocatedPorts()
        self.comp_node = NodeDef()
        self.task_id = configs['task_id']
        json_format.Parse(configs['task_input_config'], self.comp_node)
        json_format.Parse(configs['task_input_cluster_def'], cluster_define)
        json_format.Parse(configs['allocated_ports'], allocated_ports)

        self.party_id = cluster_define.self_party_idx
        self_party = cluster_define.parties[self.party_id]
        self.party_name = self_party.name

        # Must run before _add_self_party, which reads fed_port and spu_port.
        self._apply_allocated_ports(allocated_ports)

        self.cluster_config = {'parties': {}}
        self.spu_cluster_config = {'pyu': [], 'spu': {'nodes': []}}
        for party in cluster_define.parties:
            self.spu_cluster_config['pyu'].append(party.name)
            if party.name != self.party_name:
                self._add_peer_party(party)
            else:
                self._add_self_party(self_party)

        # Order SPU nodes by party name (presumably so that every party builds
        # an identical node list regardless of which one it is).
        self.spu_cluster_config['spu']['nodes'].sort(key=lambda node: node['party'])

        runtime_config = self._spu_runtime_config(len(cluster_define.parties))
        if runtime_config is not None:
            self.spu_cluster_config['spu']['runtime_config'] = runtime_config
        return self

    def _apply_allocated_ports(self, allocated_ports) -> None:
        """Copy the port numbers from an ``AllocatedPorts`` message onto self.

        Ports whose name starts with 'ray-worker' are collected into
        ``ray_worker_ports``; 'spu', 'node-manager', 'object-manager',
        'client-server' and 'fed' set their matching single-port field.
        Any other port name is ignored.
        """
        self.ray_worker_ports = []
        for port in allocated_ports.ports:
            if port.name.startswith('ray-worker'):
                self.ray_worker_ports.append(port.port)
            elif port.name == 'spu':
                self.spu_port = port.port
            elif port.name == 'node-manager':
                self.ray_node_manager_port = port.port
            elif port.name == 'object-manager':
                self.ray_object_manager_port = port.port
            elif port.name == 'client-server':
                self.ray_client_server_port = port.port
            elif port.name == 'fed':
                self.fed_port = port.port

    def _add_peer_party(self, party) -> None:
        """Record the 'fed' and 'spu' endpoints advertised by another party.

        Args:
            party: an entry of ``ClusterDefine.parties`` other than this party.
        """
        for service in party.services:
            if service.port_name == 'fed':
                self.cluster_config['parties'][party.name] = {
                    'address': self._with_default_port(service.endpoints[0])
                }
            elif service.port_name == 'spu':
                address = self._with_default_port(service.endpoints[0])
                self.spu_cluster_config['spu']['nodes'].append(
                    {
                        'party': party.name,
                        'id': f'{party.name}:0',
                        # add "http://" to force brpc to set the correct Host
                        'address': f'http://{address}',
                    }
                )

    def _add_self_party(self, party) -> None:
        """Record this party's Ray head address and its local fed/SPU addresses.

        Args:
            party: this party's own entry in ``ClusterDefine.parties``.
        """
        for service in party.services:
            if service.port_name == 'global':
                segs = service.endpoints[0].split(':')
                self.ray_node_ip_address = segs[0]
                # Only an endpoint of the exact form "host:port" supplies a
                # port; anything else falls back to 80.
                self.ray_gcs_port = int(segs[1]) if len(segs) == 2 else 80
        # 0.0.0.0 presumably lets the local servers bind on all interfaces.
        # NOTE(review): if no 'fed'/'spu' port was allocated, these addresses
        # end in ':None'.
        self.cluster_config['parties'][self.party_name] = {
            'address': f'0.0.0.0:{self.fed_port}'
        }
        self.cluster_config['self_party'] = self.party_name
        self.spu_cluster_config['spu']['nodes'].append(
            {
                'party': self.party_name,
                'id': f'{self.party_name}:0',
                'address': f'0.0.0.0:{self.spu_port}',
            }
        )

    @staticmethod
    def _with_default_port(endpoint: str, default_port: int = 80) -> str:
        """Return ``endpoint`` with ``:default_port`` appended if it has no ':'."""
        if ':' not in endpoint:
            return f'{endpoint}:{default_port}'
        return endpoint

    @staticmethod
    def _spu_runtime_config(num_parties: int):
        """Pick the SPU protocol/field for a cluster of ``num_parties`` parties.

        Returns SEMI2K over FM128 for 2 parties and ABY3 over FM64 for 3
        parties. Returns ``None`` for any other size, in which case no
        'runtime_config' entry is written.
        """
        if num_parties == 2:
            return {
                'protocol': spu.spu_pb2.SEMI2K,
                'field': spu.spu_pb2.FM128,
            }
        if num_parties == 3:
            return {
                'protocol': spu.spu_pb2.ABY3,
                'field': spu.spu_pb2.FM64,
            }
        return None