# Source code for secretflow.device.device.heu_object

# Copyright 2022 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import jax.tree_util
import ray

from secretflow.device.device.pyu import PYUObject

from .base import DeviceObject
from .register import dispatch


class HEUObject(DeviceObject):
    """HEU Object.

    Attributes:
        data: Reference (``ray.ObjectRef``) to the data held by this HEU object.
        location: The party where the data actually resides.
        is_plain: True if the data is plaintext, False if it is encrypted.
    """

    def __init__(
        self,
        device,
        data: ray.ObjectRef,
        location_party: str,
        is_plain: bool = False,
    ):
        """Wrap ``data`` held by ``location_party`` on HEU ``device``.

        Args:
            device: The HEU device this object belongs to.
            data: Reference to the underlying (plain or encrypted) data.
            location_party: Party holding the data; must be a party of ``device``.
            is_plain: Whether ``data`` is plaintext (default: encrypted).

        Raises:
            AssertionError: if ``location_party`` is not a party of ``device``.
        """
        super().__init__(device)
        self.data = data
        self.is_plain = is_plain
        assert device.has_party(
            location_party
        ), f"{location_party} is not a party of HEU {id(device)}"
        self.location = location_party

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _pyu_data(x):
        """Return ``x.data`` if ``x`` is a PYUObject, otherwise ``x`` unchanged."""
        return x.data if isinstance(x, PYUObject) else x

    @staticmethod
    def _pyu_tree_data(tree):
        """Replace every PYUObject leaf of the pytree ``tree`` by its ``data``."""
        return jax.tree_util.tree_map(HEUObject._pyu_data, tree)

    def _participant(self):
        """Return the device participant for the party holding the data."""
        return self.device.get_participant(self.location)

    def _derived(self, data, is_plain=None):
        """Wrap ``data`` in a new HEUObject on the same device and location.

        ``is_plain`` defaults to this object's plaintext flag.
        """
        if is_plain is None:
            is_plain = self.is_plain
        return HEUObject(self.device, data, self.location, is_plain)

    # ------------------------------------------------------------------
    # Python protocol
    # ------------------------------------------------------------------

    def __str__(self):
        return f'is_plain:{self.is_plain}, location:{self.location}, {self.data}'

    def __add__(self, other):
        return dispatch('add', self, other)

    def __sub__(self, other):
        return dispatch('sub', self, other)

    def __mul__(self, other):
        return dispatch('mul', self, other)

    def __matmul__(self, other):
        return dispatch('matmul', self, other)

    def __rmatmul__(self, other):
        # NOTE(review): this handles ``other @ self`` but dispatches the
        # operands as ``(self, other)``, i.e. computes ``self @ other``. That
        # agrees with ``other @ self`` only when the kernel's matmul is
        # order-insensitive (e.g. 1-D dot products). Operands are not swapped
        # here because ``dispatch`` appears to route on its first operand —
        # confirm the kernel's semantics before changing.
        return dispatch('matmul', self, other)

    def __getitem__(self, item):
        """Return a new HEUObject holding ``self[item]``.

        PYUObject leaves anywhere inside ``item`` are replaced by their data
        references before the remote call.
        """
        item = self._pyu_tree_data(item)
        return self._derived(self._participant().getitem.remote(self.data, item))

    def __setitem__(self, key, value):
        """Assign ``value`` to ``self[key]``.

        The participant at ``self.location`` performs the assignment remotely,
        and ``self.data`` is rebound to the result. Rebinding is required
        because Python ignores the return value of ``__setitem__`` for
        ``obj[key] = value``. Merely returning a new object would silently lose
        the assignment.

        Assumes the participant's ``setitem`` returns the updated data (the
        original code wrapped that result as the new object's data) — TODO
        confirm.

        Args:
            key: Index; PYUObject leaves inside it are replaced by their data.
            value: New value; a HEUObject is replaced by its ``data`` reference.

        Returns:
            ``self``, for callers that invoke ``__setitem__`` directly.
        """
        key = self._pyu_tree_data(key)
        if isinstance(value, HEUObject):
            value = value.data
        self.data = self._participant().setitem.remote(self.data, key, value)
        return self

    # ------------------------------------------------------------------
    # HEU operations
    # ------------------------------------------------------------------

    def encrypt(self, heu_audit_log: str = None):
        """Force encrypt if data is plaintext.

        Args:
            heu_audit_log: Forwarded unchanged to the participant's ``encrypt``;
                presumably an audit-log destination — TODO confirm.

        Returns:
            A new encrypted HEUObject if this object is plaintext, otherwise
            ``self``.
        """
        if not self.is_plain:
            return self
        return self._derived(
            self._participant().encrypt.remote(self.data, heu_audit_log),
            is_plain=False,
        )

    def sum(self):
        """Sum the elements of this object.

        No axis can be given. The reduction is whatever the participant's
        ``sum`` performs when called with the data alone (presumably over all
        elements — TODO confirm).

        Returns:
            A new HEUObject with the same location and plaintext flag.
        """
        return self._derived(self._participant().sum.remote(self.data))

    def dump(self, path):
        """Dump the held data (ciphertext) into files at ``path``.

        Returns:
            The ``ray.ObjectRef`` of the remote dump task. ``ray.get`` it to wait
            for completion and to surface any remote error.
        """
        return self._participant().dump.remote(self.data, path)

    def select_sum(self, item):
        """Sum of HEUObject selected elements.

        Args:
            item: Selection; PYUObject leaves inside it are replaced by their
                data references.

        Returns:
            A new HEUObject with the same location and plaintext flag.
        """
        item = self._pyu_tree_data(item)
        return self._derived(self._participant().select_sum.remote(self.data, item))

    def batch_select_sum(self, item):
        """Batched variant of :meth:`select_sum`.

        Forwards to the participant's ``batch_select_sum``. ``item`` presumably
        holds several selections — TODO confirm.

        Returns:
            A new HEUObject with the same location and plaintext flag.
        """
        item = self._pyu_tree_data(item)
        return self._derived(
            self._participant().batch_select_sum.remote(self.data, item)
        )

    def feature_wise_bucket_sum(
        self, subgroup_map, order_map, bucket_num, cumsum=False
    ):
        """Bucketed sums of HEUObject elements, per feature.

        All arguments are forwarded to the participant's
        ``feature_wise_bucket_sum``. Their exact semantics are defined there.

        Args:
            subgroup_map: Pytree; PYUObject leaves are replaced by their data.
            order_map: Pytree; PYUObject leaves are replaced by their data.
            bucket_num: Replaced by its data if it is a PYUObject.
            cumsum: Forwarded unchanged (presumably requests cumulative sums).

        Returns:
            A new HEUObject with the same location and plaintext flag.
        """
        subgroup_map = self._pyu_tree_data(subgroup_map)
        order_map = self._pyu_tree_data(order_map)
        bucket_num = self._pyu_data(bucket_num)
        return self._derived(
            self._participant().feature_wise_bucket_sum.remote(
                self.data, subgroup_map, order_map, bucket_num, cumsum
            )
        )

    def batch_feature_wise_bucket_sum(
        self, subgroup_map, order_map, bucket_num, cumsum=False
    ):
        """Batched variant of :meth:`feature_wise_bucket_sum`.

        Arguments are preprocessed the same way as in
        :meth:`feature_wise_bucket_sum`. They are then forwarded to the
        participant's ``batch_feature_wise_bucket_sum``.

        Returns:
            A new HEUObject with the same location and plaintext flag.
        """
        subgroup_map = self._pyu_tree_data(subgroup_map)
        order_map = self._pyu_tree_data(order_map)
        bucket_num = self._pyu_data(bucket_num)
        return self._derived(
            self._participant().batch_feature_wise_bucket_sum.remote(
                self.data, subgroup_map, order_map, bucket_num, cumsum
            )
        )