# Source code for secretflow.component.entry

# Copyright 2023 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from secretflow.component.ml.linear.ss_sgd import ss_sgd_predict_comp, ss_sgd_train_comp
from secretflow.component.preprocessing.train_test_split import train_test_split_comp
from secretflow.component.psi.two_party_balanced import two_party_balanced_psi_comp
from secretflow.protos.component.comp_def_pb2 import CompListDef
from secretflow.protos.component.node_def_pb2 import NodeDef

# Components exposed by this entry point. Each entry must provide
# ``definition()`` (used to build COMP_LIST) and ``eval(instance, config)``
# (used by ``eval`` below to run a node).
ALL_COMPONENTS = [
    train_test_split_comp,
    two_party_balanced_psi_comp,
    ss_sgd_train_comp,
    ss_sgd_predict_comp,
]
# Metadata copied into the generated CompListDef.
COMP_LIST_NAME = 'experimental'
COMP_LIST_DOC_STRING = 'Some experimental components. Not production ready.'
COMP_LIST_VERSION = '0.0.1'


def gen_key(domain: str, name: str, version: str) -> str:
    """Return the component lookup key ``"<domain>/<name>:<version>"``.

    The same key format is used both when registering components and when
    looking them up, so the two always agree.
    """
    return f'{domain}/{name}:{version}'
def generate_comp_list():
    """Build the component list definition and a key -> component lookup table.

    Each component in ``ALL_COMPONENTS`` contributes the definition returned by
    its ``definition()`` method to a ``CompListDef`` carrying the
    ``COMP_LIST_NAME`` / ``COMP_LIST_DOC_STRING`` / ``COMP_LIST_VERSION``
    metadata. Definitions are appended in ``(domain, name, version)`` order, so
    the generated list does not depend on the order of ``ALL_COMPONENTS``.

    Returns:
        A tuple ``(comp_list, comp_map)``:
        ``comp_list`` is the populated ``CompListDef``;
        ``comp_map`` maps ``gen_key(domain, name, version)`` to the component.

    Raises:
        ValueError: if two components share the same domain, name and version.
            Without this check the later one would silently replace the earlier
            one in ``comp_map`` while both definitions stayed in ``comp_list``.
    """
    comp_list = CompListDef()
    comp_list.name = COMP_LIST_NAME
    comp_list.doc_string = COMP_LIST_DOC_STRING
    comp_list.version = COMP_LIST_VERSION

    comp_map = {}
    all_comp_defs = []
    for x in ALL_COMPONENTS:
        x_def = x.definition()
        key = gen_key(x_def.domain, x_def.name, x_def.version)
        if key in comp_map:
            raise ValueError(f"duplicate component definition: {key}")
        comp_map[key] = x
        all_comp_defs.append(x_def)

    # Sort for a deterministic output order independent of ALL_COMPONENTS.
    all_comp_defs.sort(key=lambda d: (d.domain, d.name, d.version))
    comp_list.comps.extend(all_comp_defs)
    return comp_list, comp_map


# Built once at import time; ``eval`` dispatches through COMP_MAP.
COMP_LIST, COMP_MAP = generate_comp_list()
def eval(instance: NodeDef, secretflow_cluster_config):
    """Run the registered component referenced by ``instance``.

    The component is looked up in ``COMP_MAP`` using the key built by
    ``gen_key`` from ``instance.domain``, ``instance.name`` and
    ``instance.version``. Its ``eval`` method is then called with ``instance``
    and ``secretflow_cluster_config``.

    Note: this function shadows the builtin ``eval`` inside this module; the
    name is kept because it is part of the module's public interface.

    Args:
        instance: node description whose domain/name/version select the component.
        secretflow_cluster_config: forwarded unchanged to the component's ``eval``.

    Returns:
        Whatever the selected component's ``eval`` returns.

    Raises:
        RuntimeError: if no registered component matches the key.
    """
    key = gen_key(instance.domain, instance.name, instance.version)
    # Single lookup. A KeyError raised inside comp.eval must not be
    # misreported as "component not found", so no try/except around the call.
    comp = COMP_MAP.get(key)
    if comp is None:
        raise RuntimeError(f"component {key} is not found.")
    return comp.eval(instance, secretflow_cluster_config)