# Source code for secretflow.distributed.primitive

# Copyright 2023 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import inspect
from functools import partial
from typing import List, Union

import fed
import jax
import ray
from ray import Language
from ray._private import ray_option_utils
from ray.actor import ActorClass, _inject_tracing_into_class, ray_constants
from ray.remote_function import RemoteFunction

from secretflow.utils.ray_compatibility import ray_version_less_than_2_0_0

_production_mode = False


def set_production(mode: bool):
    """Store ``mode`` in the module-level ``_production_mode`` flag.

    Args:
        mode: The new value of the flag.
    """
    global _production_mode
    _production_mode = mode
def production_mode():
    """Return the current value of the module-level ``_production_mode`` flag."""
    # A plain read of a module global does not need a ``global`` declaration.
    return _production_mode
def _is_cython(obj): """Check if an object is a Cython function or method""" # TODO(suo): We could split these into two functions, one for Cython # functions and another for Cython methods. # TODO(suo): There doesn't appear to be a Cython function 'type' we can # check against via isinstance. Please correct me if I'm wrong. def check_cython(x): return type(x).__name__ == "cython_function_or_method" # Check if function or method, respectively return check_cython(obj) or ( hasattr(obj, "__func__") and check_cython(obj.__func__) )
def remote(*args, **kwargs):
    """Forward to ``fed.remote`` in production mode, otherwise to ``ray_remote``.

    All positional and keyword arguments are passed through unchanged and the
    selected backend's return value is returned as-is.
    """
    backend = fed.remote if production_mode() else ray_remote
    return backend(*args, **kwargs)
def get(
    object_refs: Union[
        Union[ray.ObjectRef, List[ray.ObjectRef]],
        Union[fed.FedObject, List[fed.FedObject]],
    ]
):
    """Fetch the value(s) behind ``object_refs``.

    Delegates to ``fed.get`` in production mode and to ``ray.get`` otherwise,
    returning whatever the selected backend returns.

    Args:
        object_refs: A single reference or a list of references.
    """
    fetch = fed.get if production_mode() else ray.get
    return fetch(object_refs)
def kill(actor, *, no_restart=True):
    """Kill ``actor`` through the active backend.

    Delegates to ``fed.kill`` in production mode and to ``ray.kill``
    otherwise, forwarding ``no_restart`` as a keyword argument.

    Args:
        actor: The actor to kill.
        no_restart: Passed through to the backend's ``kill``.
    """
    terminate = fed.kill if production_mode() else ray.kill
    return terminate(actor, no_restart=no_restart)
def shutdown():
    """Shut down the active backend.

    Calls ``fed.shutdown()`` in production mode and ``ray.shutdown()``
    otherwise, returning the backend's result.
    """
    stop = fed.shutdown if production_mode() else ray.shutdown
    return stop()
def _resolve_args(*args, **kwargs):
    """Replace every ``ray.ObjectRef`` leaf in the arguments with its value.

    ``(args, kwargs)`` is flattened with ``jax.tree_util.tree_flatten``, all
    leaves that are ``ray.ObjectRef`` instances are fetched with a single
    ``ray.get`` call, and the tree is rebuilt with the fetched values in place.

    Returns:
        The ``(args, kwargs)`` pair produced by unflattening the updated
        leaves.
    """
    leaves, treedef = jax.tree_util.tree_flatten((args, kwargs))

    # Positions of the leaves that still need to be fetched.
    ref_positions = [
        index for index, leaf in enumerate(leaves) if isinstance(leaf, ray.ObjectRef)
    ]
    fetched = ray.get([leaves[index] for index in ref_positions])

    for index, value in zip(ref_positions, fetched):
        leaves[index] = value

    resolved_args, resolved_kwargs = jax.tree_util.tree_unflatten(treedef, leaves)
    return resolved_args, resolved_kwargs
class RemoteFunctionWrapper(RemoteFunction):
    """``RemoteFunction`` subclass that resolves ``ray.ObjectRef`` arguments
    before submission and can attach a ``{party: 1}`` resource requirement."""

    def _remote(self, *args, **kwargs):
        """Resolve ``ray.ObjectRef`` leaves in the arguments via
        ``_resolve_args`` and delegate to ``RemoteFunction._remote``."""
        args, kwargs = _resolve_args(*args, **kwargs)
        return super()._remote(*args, **kwargs)

    def party(self, party: str):
        """Record ``party`` and add ``{party: 1}`` to the default resources.

        Note that this stores the name in the instance attribute ``party``,
        which shadows this method on the instance afterwards.

        Args:
            party: Name used as the resource key.

        Returns:
            ``self``, so calls can be chained.
        """
        self.party = party
        # Build a new dict rather than updating in place, so a ``resources``
        # dict supplied by the caller is never mutated. ``or {}`` also covers
        # an explicit ``resources=None``.
        current = self._default_options.get('resources') or {}
        self._default_options['resources'] = {**current, party: 1}
        return self

    def options(self, **task_options):
        """Add ``{party: 1}`` to ``task_options['resources']`` when a party
        has been set, then delegate to ``RemoteFunction.options``.

        Args:
            **task_options: Options forwarded to ``RemoteFunction.options``.
        """
        # Look only at the instance dict: ``hasattr(self, 'party')`` is always
        # true because of the ``party`` method, and the bound method is truthy,
        # so it would otherwise be used as a resource key.
        party = vars(self).get('party')
        if party:
            # Copy instead of mutating the caller's dict; tolerate ``None``.
            current = task_options.get('resources') or {}
            task_options['resources'] = {**current, party: 1}
        return super().options(**task_options)
class ActorClassWrapper(ActorClass):
    """``ActorClass`` subclass that resolves ``ray.ObjectRef`` constructor
    arguments before submission and can attach a ``{party: 1}`` resource
    requirement."""

    def party(self, party: str):
        """Record ``party`` and add ``{party: 1}`` to the default resources.

        Note that this stores the name in the instance attribute ``party``,
        which shadows this method on the instance afterwards.

        Args:
            party: Name used as the resource key.

        Returns:
            ``self``, so calls can be chained.
        """
        self.party = party
        # Build a new dict rather than updating in place, so a ``resources``
        # dict supplied by the caller is never mutated. ``or {}`` also covers
        # an explicit ``resources=None``.
        current = self._default_options.get('resources') or {}
        self._default_options['resources'] = {**current, party: 1}
        return self

    def options(self, **actor_options):
        """Add ``{party: 1}`` to ``actor_options['resources']`` when a party
        has been set, then delegate to ``ActorClass.options``.

        Args:
            **actor_options: Options forwarded to ``ActorClass.options``.
        """
        # Look only at the instance dict: ``hasattr(self, 'party')`` is always
        # true because of the ``party`` method, and the bound method is truthy,
        # so it would otherwise be used as a resource key.
        party = vars(self).get('party')
        if party:
            # Copy instead of mutating the caller's dict; tolerate ``None``.
            current = actor_options.get('resources') or {}
            actor_options['resources'] = {**current, party: 1}
        return super().options(**actor_options)

    def remote(self, *args, **kwargs):
        """Resolve ``ray.ObjectRef`` leaves in the arguments via
        ``_resolve_args`` and delegate to ``ActorClass.remote``."""
        args, kwargs = _resolve_args(*args, **kwargs)
        return super().remote(*args, **kwargs)
def _make_actor(cls, actor_options):
    """Build an ``ActorClassWrapper`` from a plain class.

    Args:
        cls: The class to wrap.
        actor_options: Actor options dict; a ``max_restarts`` entry other
            than -1 is capped in place at ``ray_constants.MAX_INT64_VALUE``.

    Returns:
        The result of ``ActorClassWrapper._ray_from_modified_class``.
    """
    # The helpers are imported from different locations depending on the
    # installed Ray version.
    if ray_version_less_than_2_0_0():
        from ray import ActorClassID
        from ray.actor import modify_class as _modify_class
    else:
        from ray.actor import ActorClassID, _modify_class

    modified_cls = _modify_class(cls)
    _inject_tracing_into_class(modified_cls)

    # -1 represents infinite restart and is left untouched. Any other value
    # is capped to make sure we don't pass too big of an int to C++, causing
    # an overflow.
    if "max_restarts" in actor_options and actor_options["max_restarts"] != -1:
        actor_options["max_restarts"] = min(
            actor_options["max_restarts"], ray_constants.MAX_INT64_VALUE
        )

    actor_class_id = ActorClassID.from_random()
    return ActorClassWrapper._ray_from_modified_class(
        modified_cls,
        actor_class_id,
        actor_options,
    )


def _make_remote(function_or_class, options):
    """Wrap a function or class as its remote counterpart.

    Functions (including Cython functions/methods, per ``_is_cython``) have
    their options validated as task options and become a
    ``RemoteFunctionWrapper``; classes have their options validated as actor
    options and are passed to ``_make_actor``.

    Raises:
        TypeError: If ``function_or_class`` is neither a function nor a class.
    """
    is_function_like = inspect.isfunction(function_or_class) or _is_cython(
        function_or_class
    )
    if is_function_like:
        ray_option_utils.validate_task_options(options, in_options=False)
        return RemoteFunctionWrapper(
            Language.PYTHON,
            function_or_class,
            None,
            options,
        )

    if not inspect.isclass(function_or_class):
        raise TypeError(
            "The @ray.remote decorator must be applied to either a function or a class."
        )

    ray_option_utils.validate_actor_options(options, in_options=False)
    return _make_actor(function_or_class, options)
def ray_remote(*args, **kwargs):
    """Decorator entry point supporting both bare and parameterised use.

    Called with a single callable and no keyword arguments, it wraps that
    callable immediately via ``_make_remote`` with empty options. Called with
    keyword arguments only, it returns a ``functools.partial`` of
    ``_make_remote`` bound to those options, to be applied to the decorated
    function or class.

    Raises:
        AssertionError: If the arguments match neither form.
    """
    if len(args) == 1 and not kwargs and callable(args[0]):
        # This is the case where the decorator is just @ray.remote.
        # "args[0]" is the class or function under the decorator.
        return _make_remote(args[0], {})

    # Raise explicitly instead of using ``assert`` so the check still runs
    # under ``python -O``. AssertionError is kept so callers see the same
    # exception type and message as before.
    if args or not kwargs:
        raise AssertionError(ray_option_utils.remote_args_error_string)
    return partial(_make_remote, options=kwargs)