Source code for spu.api
# Copyright 2021 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import os
from typing import List
from cachetools import LRUCache, cached
from . import libspu # type: ignore
from . import spu_pb2
class Runtime:
    """The SPU Virtual Machine Slice.

    Thin Python facade: every method forwards to the
    ``libspu.RuntimeWrapper`` instance stored in ``self._vm``.
    """

    def __init__(self, link: libspu.link.Context, config: spu_pb2.RuntimeConfig):
        """Constructor of an SPU Runtime.

        Args:
            link (libspu.link.Context): Link context.
            config (spu_pb2.RuntimeConfig): SPU Runtime Config. It is handed
                to the native wrapper in serialized form.
        """
        self._vm = libspu.RuntimeWrapper(link, config.SerializeToString())

    def run(self, executable: spu_pb2.ExecutableProto) -> None:
        """Run an SPU executable.

        Args:
            executable (spu_pb2.ExecutableProto): executable. It is
                serialized before being handed to the native wrapper.
        """
        # The wrapper's result is forwarded unchanged; the annotation
        # declares it to be None.
        return self._vm.Run(executable.SerializeToString())

    def set_var(self, name: str, value: bytes) -> None:
        """Set an SPU value.

        Args:
            name (str): Id of value.
            value (bytes): value data. Presumably a serialized
                ``spu_pb2.ValueProto`` (the type the earlier docs named);
                confirm against the native wrapper.
        """
        # The wrapper's result is forwarded unchanged; the annotation
        # declares it to be None.
        return self._vm.SetVar(name, value)

    def get_var(self, name: str) -> bytes:
        """Get an SPU value.

        Args:
            name (str): Id of value.

        Returns:
            bytes: value data as returned by the native wrapper. Presumably
            a serialized ``spu_pb2.ValueProto``; confirm against the native
            wrapper.
        """
        return self._vm.GetVar(name)

    def del_var(self, name: str) -> None:
        """Delete an SPU value.

        Args:
            name (str): Id of the value.
        """
        self._vm.DelVar(name)

    def clear(self) -> None:
        """Delete all SPU values."""
        self._vm.Clear()
class Io:
    """The SPU IO interface.

    Thin Python facade: every method forwards to the ``libspu.IoWrapper``
    instance stored in ``self._io``.
    """

    def __init__(self, world_size: int, config: spu_pb2.RuntimeConfig):
        """Constructor of an SPU Io.

        Args:
            world_size (int): # of participants of SPU Device.
            config (spu_pb2.RuntimeConfig): SPU Runtime Config. It is handed
                to the native wrapper in serialized form.
        """
        self._io = libspu.IoWrapper(world_size, config.SerializeToString())

    def make_shares(
        self, x: 'np.ndarray', vtype: spu_pb2.Visibility, owner_rank: int = -1
    ) -> List[bytes]:
        """Convert from NumPy array to list of SPU value(s).

        Args:
            x (np.ndarray): input.
            vtype (spu_pb2.Visibility): visibility.
            owner_rank (int): the index of the trusted piece. if >= 0,
                colocation optimization may be applied.

        Returns:
            List[bytes]: output shares. Presumably each entry is a
            serialized ``spu_pb2.ValueProto`` (the type the earlier docs
            named); confirm against the native wrapper.
        """
        return self._io.MakeShares(x, vtype, owner_rank)

    def reconstruct(self, str_shares: List[bytes]) -> 'np.ndarray':
        """Convert from list of SPU value(s) to NumPy array.

        Args:
            str_shares (List[bytes]): input shares, in the same form that
                ``make_shares`` is annotated to return.

        Returns:
            np.ndarray: output.
        """
        return self._io.Reconstruct(str_shares)
@cached(cache=LRUCache(maxsize=128))
def _spu_compilation(ir_text: str, ir_type: str, json_meta: str):
    """Compile IR through ``libspu.compile``, memoizing up to 128 results.

    Args:
        ir_text (str): IR text forwarded to ``libspu.compile``.
        ir_type (str): IR kind forwarded to ``libspu.compile``.
        json_meta (str): JSON metadata forwarded to ``libspu.compile``.

    Returns:
        Whatever ``libspu.compile`` returns for these arguments.
    """
    # The dump directory comes from the environment; an unset or empty
    # variable is forwarded as an empty string.
    # NOTE(review): the cache is keyed on the three call arguments only, so
    # a later change to SPU_IR_DUMP_DIR has no effect for inputs that are
    # already cached -- confirm this is intended.
    pp_dir = os.getenv('SPU_IR_DUMP_DIR')
    return libspu.compile(ir_text, ir_type, json_meta, pp_dir or "")
def compile(ir_text: str, ir_type: str, vis: List[spu_pb2.Visibility]) -> str:
    """Compile from HLO/MHLO IR to SPU bytecode.

    Args:
        ir_text (str): HLO/MHLO IR to compile.
        ir_type (str): "hlo" or "mhlo".
        vis (List[spu_pb2.Visibility]): Visibilities, stored as the
            ``inputs`` field of the ``spu_pb2.XlaMeta`` message that
            accompanies the IR.

    Returns:
        str: compiled SPU bytecode, as returned by the cached compilation
        helper.
    """
    from google.protobuf.json_format import MessageToJson

    # todo: rename spu_pb2.XlaMeta to IrMeta?
    # The metadata is passed as a JSON string rather than as the message
    # object.
    json_meta = MessageToJson(spu_pb2.XlaMeta(inputs=vis))
    return _spu_compilation(ir_text, ir_type, json_meta)