C++ Example: Logistic Regression

To use the SPU C++ API, you first have to build SPU from source. This document shows how to write a privacy-preserving logistic regression program with the SPU C++ API.

Logistic Regression

// Copyright 2021 Ant Group Co., Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// clang-format off
// To run the example, start two terminals:
// > bazel run //examples/cpp:simple_lr -- --dataset=examples/data/perfect_logit_a.csv --has_label=true
// > bazel run //examples/cpp:simple_lr -- --dataset=examples/data/perfect_logit_b.csv --rank=1
// clang-format on

#include <cstdlib>
#include <fstream>
#include <iostream>
#include <stdexcept>
#include <vector>

#include "examples/cpp/utils.h"
#include "spdlog/spdlog.h"
#include "xtensor/xarray.hpp"
#include "xtensor/xcsv.hpp"
#include "xtensor/xio.hpp"
#include "xtensor/xview.hpp"

#include "libspu/device/io.h"
#include "libspu/kernel/hal/hal.h"
#include "libspu/kernel/hal/type_cast.h"

using namespace spu::kernel;

/// Performs one gradient-descent step of logistic regression on one batch.
///
/// A column of ones is appended to `x` (as its last column), so the last
/// entry of `w` acts as the bias. The update computed is
///   w_new = w - (learning_rate / B) * [x, 1]^T * (logistic([x, 1] * w) - y)
/// where B = y.shape()[0] is the number of rows in the batch.
///
/// @param ctx            HAL context used for every SPU operation.
/// @param x              batch features; rows are samples.
/// @param y              batch labels, one row per row of `x`.
/// @param w              current weights; expected shape {x.shape()[1] + 1, 1}.
/// @param learning_rate  gradient-descent step size (defaults to the value
///                       that was previously hard-coded, 1e-4).
/// @return the updated weights, same shape as `w`.
spu::Value train_step(spu::HalContext* ctx, const spu::Value& x,
                      const spu::Value& y, const spu::Value& w,
                      float learning_rate = 0.0001F) {
  // Pad x with a ones column so the last weight is the bias term.
  auto padding = hal::constant(ctx, 1.0F, spu::DT_FXP, {x.shape()[0], 1});
  auto padded_x = hal::concatenate(ctx, {x, hal::seal(ctx, padding)}, 1);
  auto pred = hal::logistic(ctx, hal::matmul(ctx, padded_x, w));

  SPDLOG_DEBUG("[SSLR] Err = Pred - Y");
  auto err = hal::sub(ctx, pred, y);

  SPDLOG_DEBUG("[SSLR] Grad = X.t * Err");
  auto grad = hal::matmul(ctx, hal::transpose(ctx, padded_x), err);

  SPDLOG_DEBUG("[SSLR] Step = LR / B * Grad");
  // Both LR and the batch size B are public values. Fold LR / B into a single
  // cleartext float instead of computing it with a fixed-point reciprocal
  // followed by a fixed-point multiply. Each of those steps may add
  // approximation/truncation error.
  const float step_scale = learning_rate / static_cast<float>(y.shape()[0]);
  auto p1 = hal::constant(ctx, step_scale, spu::DT_FXP);
  auto step = hal::mul(ctx, hal::broadcast_to(ctx, p1, grad.shape()), grad);

  SPDLOG_DEBUG("[SSLR] W = W - Step");
  return hal::sub(ctx, w, step);
}

/// Trains logistic-regression weights with mini-batch gradient descent.
///
/// Each epoch walks the rows of `x`/`y` in consecutive, non-overlapping
/// batches of `bsize` rows and applies train_step to each batch. Trailing
/// rows that do not fill a complete batch are not used.
///
/// @param ctx        HAL context used for every SPU operation.
/// @param x          features; rows are samples.
/// @param y          labels; same number of rows as `x`.
/// @param num_epoch  number of passes over the data.
/// @param bsize      rows per batch; must be non-zero.
/// @return weights of shape {x.shape()[1] + 1, 1}, starting from all zeros.
///         The extra row pairs with the ones column that train_step appends
///         to each batch.
/// @throws std::invalid_argument if `bsize` is zero.
spu::Value train(spu::HalContext* ctx, const spu::Value& x, const spu::Value& y,
                 size_t num_epoch, size_t bsize) {
  // Reject a zero batch size up front (reachable via --batch_size=0).
  // Otherwise the division below would be an integer division by zero.
  if (bsize == 0) {
    throw std::invalid_argument("train: batch size must be greater than 0");
  }

  const auto num_rows = static_cast<size_t>(x.shape()[0]);
  const size_t num_iter = num_rows / bsize;
  if (num_iter == 0) {
    // Not fatal, but the weights returned below stay at their zero init.
    SPDLOG_WARN(
        "batch size {} is larger than the number of rows {}; no training "
        "step will run",
        bsize, num_rows);
  }

  auto w = hal::constant(ctx, 0.0F, spu::DT_FXP, {x.shape()[1] + 1, 1});

  // Run train loop
  for (size_t epoch = 0; epoch < num_epoch; ++epoch) {
    for (size_t iter = 0; iter < num_iter; ++iter) {
      SPDLOG_INFO("Running train iteration {}", iter);

      const auto rows_beg = static_cast<int64_t>(iter * bsize);
      const auto rows_end = rows_beg + static_cast<int64_t>(bsize);

      const auto x_slice =
          hal::slice(ctx, x, {rows_beg, 0}, {rows_end, x.shape()[1]}, {});

      const auto y_slice =
          hal::slice(ctx, y, {rows_beg, 0}, {rows_end, y.shape()[1]}, {});

      w = train_step(ctx, x_slice, y_slice, w);
    }
  }

  return w;
}

/// Computes model predictions logistic([x, 1] * weight) for every row of `x`.
///
/// This uses the same bias padding and logistic link as train_step. The
/// returned scores are therefore on the same scale as the training labels
/// (the quantity train_step subtracts `y` from). Returning the raw linear
/// term instead would make the MSE computed in main compare logits with
/// labels.
///
/// @param ctx     HAL context used for every SPU operation.
/// @param x       features; rows are samples.
/// @param weight  weights as returned by train, shape {x.shape()[1] + 1, 1}.
/// @return one prediction per row of `x`, shape {x.shape()[0], 1}.
spu::Value inference(spu::HalContext* ctx, const spu::Value& x,
                     const spu::Value& weight) {
  // Same ones-column padding as in training, so the last weight is the bias.
  auto padding = hal::constant(ctx, 1.0F, spu::DT_FXP, {x.shape()[0], 1});
  auto padded_x = hal::concatenate(ctx, {x, hal::seal(ctx, padding)}, 1);
  return hal::logistic(ctx, hal::matmul(ctx, padded_x, weight));
}

/// Returns the sum of squared differences between paired elements of
/// `y_true` and `y_pred`, visited in iteration order.
///
/// If the arrays differ in size, only the first
/// min(y_true.size(), y_pred.size()) pairs are summed.
float SSE(const xt::xarray<float>& y_true, const xt::xarray<float>& y_pred) {
  // Accumulate in double. A float running sum loses precision as the number
  // of samples grows.
  double sse = 0.0;

  auto true_it = y_true.begin();
  auto pred_it = y_pred.begin();
  for (; true_it != y_true.end() && pred_it != y_pred.end();
       ++true_it, ++pred_it) {
    // Use plain multiplication rather than std::pow(diff, 2). It is cheaper
    // and does not rely on <cmath> being included transitively.
    const double diff =
        static_cast<double>(*true_it) - static_cast<double>(*pred_it);
    sse += diff * diff;
  }
  return static_cast<float>(sse);
}

/// Mean squared error: SSE(y_true, y_pred) divided by the element count of
/// `y_true`.
///
/// An empty `y_true` gives 0 / 0 in floating point (NaN under IEEE-754).
float MSE(const xt::xarray<float>& y_true, const xt::xarray<float>& y_pred) {
  const float total = SSE(y_true, y_pred);
  const auto count = static_cast<float>(y_true.size());
  return total / count;
}

// Command-line flags. They are registered with LLVM's cl library and parsed
// by llvm::cl::ParseCommandLineOptions in main. The `rank` flag used in the
// run instructions is not declared here. NOTE(review): it is presumably
// declared alongside MakeHalContext in examples/cpp/utils.h; confirm.

// Path of the CSV file this party opens and loads in main.
llvm::cl::opt<std::string> Dataset("dataset", llvm::cl::init("data.csv"),
                                   llvm::cl::desc("only csv is supported"));
// Number of leading rows xt::load_csv skips when reading the dataset.
// Default 1, presumably for a header line.
llvm::cl::opt<uint32_t> SkipRows(
    "skip_rows", llvm::cl::init(1),
    llvm::cl::desc("skip number of rows from dataset"));
// Forwarded to infeed as self_has_label. When true, the last column of this
// party's dataset is split off and used as the label.
llvm::cl::opt<bool> HasLabel(
    "has_label", llvm::cl::init(false),
    llvm::cl::desc("if true, label is the last column of dataset"));
// Rows per mini-batch; forwarded to train as bsize.
llvm::cl::opt<uint32_t> BatchSize("batch_size", llvm::cl::init(21),
                                  llvm::cl::desc("size of each batch"));
// Number of passes over the training data; forwarded to train as num_epoch.
llvm::cl::opt<uint32_t> NumEpoch("num_epoch", llvm::cl::init(1),
                                 llvm::cl::desc("number of epoch"));

/// Hands this party's local data to the SPU device and assembles the joint
/// training inputs.
///
/// This party registers its feature columns under the name "x-<rank>". When
/// `self_has_label` is true, the last column of `ds` is registered separately
/// under "label", and only the remaining columns count as features. After
/// `cio.sync()`, the feature slices "x-0" .. "x-<world_size-1>" are read back
/// and concatenated column-wise (axis 1) in rank order. This assumes every
/// party's rows are aligned sample-for-sample.
///
/// NOTE(review): "label" is read unconditionally, so exactly one party is
/// presumably expected to pass self_has_label=true. Confirm what ColocatedIo
/// does when none (or several) do.
///
/// @param hctx            HAL context; its link context supplies this party's
///                        rank.
/// @param ds              this party's local dataset.
/// @param self_has_label  whether this party owns the label column.
/// @return (joint features, labels) as device values.
std::pair<spu::Value, spu::Value> infeed(spu::HalContext* hctx,
                                         const xt::xarray<float>& ds,
                                         bool self_has_label) {
  spu::device::ColocatedIo cio(hctx);
  const auto local_name = fmt::format("x-{}", hctx->lctx()->Rank());

  if (!self_has_label) {
    cio.hostSetVar(local_name, ds);
  } else {
    using namespace xt::placeholders;  // brings `_` into scope for xt::range
    const auto label_col = ds.shape(1) - 1;
    xt::xarray<float> features =
        xt::view(ds, xt::all(), xt::range(_, label_col));
    xt::xarray<float> label_column =
        xt::view(ds, xt::all(), xt::range(label_col, _));
    cio.hostSetVar(local_name, features);
    cio.hostSetVar("label", label_column);
  }
  cio.sync();

  // Concatenate every party's feature slice, in rank order.
  auto joint_x = cio.deviceGetVar("x-0");
  for (size_t party = 1; party < cio.getWorldSize(); ++party) {
    auto peer_x = cio.deviceGetVar(fmt::format("x-{}", party));
    joint_x = hal::concatenate(hctx, {joint_x, peer_x}, 1);
  }
  auto joint_y = cio.deviceGetVar("label");

  return {joint_x, joint_y};
}

/// Entry point for one party. It reads the local CSV dataset, jointly trains
/// logistic regression via SPU, scores the (joint) training features, and
/// prints the MSE between revealed labels and revealed scores.
///
/// @return EXIT_FAILURE if the dataset file cannot be opened, 0 otherwise.
int main(int argc, char** argv) {
  llvm::cl::ParseCommandLineOptions(argc, argv);

  // Read the local dataset, skipping the first SkipRows rows.
  xt::xarray<float> ds;
  {
    std::ifstream file(Dataset.getValue());
    if (!file) {
      spdlog::error("open file={} failed", Dataset.getValue());
      // Return instead of calling exit(-1): locals are destroyed normally,
      // and the process reports a portable failure status (exit(-1) is
      // seen as 255 on POSIX shells).
      return EXIT_FAILURE;
    }
    ds = xt::load_csv<float>(file, ',', SkipRows.getValue());
  }

  auto hctx = MakeHalContext();

  // Joint features (all parties' columns) and the labels.
  const auto& [x, y] = infeed(hctx.get(), ds, HasLabel.getValue());

  const auto w =
      train(hctx.get(), x, y, NumEpoch.getValue(), BatchSize.getValue());

  // Evaluated on the same features used for training.
  const auto scores = inference(hctx.get(), x, w);

  // Reveal labels and scores, then dump them as plaintext floats so the
  // MSE can be computed locally.
  xt::xarray<float> revealed_labels =
      hal::dump_public_as<float>(hctx.get(), hal::reveal(hctx.get(), y));
  xt::xarray<float> revealed_scores =
      hal::dump_public_as<float>(hctx.get(), hal::reveal(hctx.get(), scores));

  auto mse = MSE(revealed_labels, revealed_scores);
  std::cout << "MSE = " << mse << "\n";

  return 0;
}

Run it

Start two terminals.

In the first terminal, start party 0, which holds the labels:

bazel run //examples/cpp:simple_lr -- -rank 0 -dataset examples/cpp/data/perfect_logit_a.csv -has_label=true

In the second terminal, start party 1:

bazel run //examples/cpp:simple_lr -- -rank 1 -dataset examples/cpp/data/perfect_logit_b.csv