{ "cells": [ { "cell_type": "markdown", "id": "eb4f5fec-1a5c-4b6f-b583-fb369472e94b", "metadata": {}, "source": [ "# PSI On SPU" ] }, { "cell_type": "markdown", "id": "4172c7c4", "metadata": {}, "source": [ ">The following code is for demonstration only. It is **NOT for production** because of system security concerns; please **DO NOT** use it directly in production." ] }, { "cell_type": "markdown", "id": "f9b632d8-b12f-44a1-8a75-5d9c0a704a38", "metadata": {}, "source": [ "PSI ([Private Set Intersection](https://en.wikipedia.org/wiki/Private_set_intersection)) is a cryptographic technique that allows two parties holding sets to compare encrypted versions of these sets in order to compute the intersection. Neither party reveals anything to the counterparty except for the elements in the intersection.\n", "\n", "In SecretFlow, the SPU device supports three PSI protocols:\n", "\n", "- [ECDH](https://ieeexplore.ieee.org/document/6234849/): semi-honest, based on public key encryption, suitable for small datasets.\n", "- [KKRT](https://eprint.iacr.org/2016/799.pdf): semi-honest, based on cuckoo hashing and OT extension, suitable for large datasets.\n", "- [BC22PCG](https://eprint.iacr.org/2022/334): semi-honest, PSI from pseudorandom correlation generators.\n", "\n", "Before we start, we need to initialize the environment. The following three nodes `alice`, `bob`, and `carol` are created on a single machine to simulate multiple participants." 
] }, { "cell_type": "code", "execution_count": 1, "id": "3d7c4fa2-ea20-4e0d-b1ad-648cce23e729", "metadata": {}, "outputs": [], "source": [ "import secretflow as sf\n", "\n", "# In case you have a running secretflow runtime already.\n", "sf.shutdown()\n", "\n", "sf.init(['alice', 'bob', 'carol'], address='local')" ] }, { "cell_type": "markdown", "id": "00a798bd", "metadata": {}, "source": [] }, { "cell_type": "markdown", "id": "5ed0a08b-3aa4-4fa6-9e1d-0caba207bdf5", "metadata": {}, "source": [ "## Preparing dataset" ] }, { "cell_type": "markdown", "id": "b4c16f07-1c67-4bad-af70-d8a4fe9266f3", "metadata": {}, "source": [ "First, we need a dataset for constructing a vertically partitioned scenario. For simplicity, we use the [iris](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_iris.html) dataset here. We add two columns to it for the subsequent single-column and multi-column intersection demonstrations:\n", "\n", "- uid: unique sample ID.\n", "- month: simulates a scenario where samples are generated monthly. The first 50% of the samples are generated in January, and the last 50% in February." ] }, { "cell_type": "code", "execution_count": 2, "id": "31f0a010-0a2e-4ee2-996a-169d7cb2731d", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " sepal length (cm) sepal width (cm) petal length (cm) petal width (cm) \\\n", "0 5.1 3.5 1.4 0.2 \n", "1 4.9 3.0 1.4 0.2 \n", "2 4.7 3.2 1.3 0.2 \n", "3 4.6 3.1 1.5 0.2 \n", "4 5.0 3.6 1.4 0.2 \n", ".. ... ... ... ... \n", "145 6.7 3.0 5.2 2.3 \n", "146 6.3 2.5 5.0 1.9 \n", "147 6.5 3.0 5.2 2.0 \n", "148 6.2 3.4 5.4 2.3 \n", "149 5.9 3.0 5.1 1.8 \n", "\n", " uid month \n", "0 0 Jan \n", "1 1 Jan \n", "2 2 Jan \n", "3 3 Jan \n", "4 4 Jan \n", ".. ... ... \n", "145 145 Feb \n", "146 146 Feb \n", "147 147 Feb \n", "148 148 Feb \n", "149 149 Feb \n", "\n", "[150 rows x 6 columns]" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import numpy as np\n", "from sklearn.datasets import load_iris\n", "\n", "data, target = load_iris(return_X_y=True, as_frame=True)\n", "data['uid'] = np.arange(len(data)).astype('str')\n", "data['month'] = ['Jan'] * 75 + ['Feb'] * 75\n", "\n", "data" ] }, { "cell_type": "markdown", "id": "1bb263ff-e8a3-4066-ab30-ea01030a9f18", "metadata": {}, "source": [ "In the actual scenario, the sample data is provided by each participant, and the fields for intersection need to be agreed in advance:\n", "\n", "- The intersection field can be single or multiple.\n", "- The intersection field must be unique. 
If there are duplicates, the data must be deduplicated in advance.\n", "\n", "For example, the following is the data provided by alice for PSI. The intersection fields are `uid` and `month`, and we can see that [1, 'Jan'] is duplicated.\n", "```\n", "alice.csv\n", "---------\n", "uid month c0\n", "1 Jan 5.8\n", "2 Jan 5.4\n", "1 Jan 5.8\n", "1 Feb 7.4\n", "```\n", "The data after deduplication is\n", "```\n", "alice.csv\n", "---------\n", "uid month c0\n", "1 Jan 5.8\n", "2 Jan 5.4\n", "1 Feb 7.4\n", "```\n", "We randomly sample the iris data three times to simulate the data provided by `alice`, `bob`, and `carol`. The three datasets are not aligned with each other.\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "037542dd-7945-4665-9d6a-7d805ea52b20", "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "os.makedirs('.data', exist_ok=True)\n", "da, db, dc = data.sample(frac=0.9), data.sample(frac=0.8), data.sample(frac=0.7)\n", "\n", "da.to_csv('.data/alice.csv', index=False)\n", "db.to_csv('.data/bob.csv', index=False)\n", "dc.to_csv('.data/carol.csv', index=False)" ] }, { "cell_type": "markdown", "id": "d54451db-c9ba-45ca-877b-06619e03215f", "metadata": {}, "source": [ "## Two parties PSI" ] }, { "cell_type": "markdown", "id": "7c12512f-4889-4f71-a539-e5bb7f56a9a5", "metadata": {}, "source": [ "We virtualize three logical devices on the physical device:\n", "\n", "- alice, bob: PYU devices, each responsible for the local plaintext computation of its participant.\n", "- spu: SPU device, consisting of alice and bob, responsible for the ciphertext computation of the two parties." 
] }, { "cell_type": "code", "execution_count": 4, "id": "ff729adb-f89a-499d-999f-6d884f2203e0", "metadata": { "tags": [] }, "outputs": [], "source": [ "alice, bob = sf.PYU('alice'), sf.PYU('bob')\n", "spu = sf.SPU(sf.utils.testing.cluster_def(['alice', 'bob']))" ] }, { "cell_type": "markdown", "id": "fb2b161a-5aa9-4d06-b866-48b25273e48b", "metadata": {}, "source": [ "### Single-column PSI" ] }, { "cell_type": "markdown", "id": "c12444e3-1da0-426e-add2-609770f8f259", "metadata": {}, "source": [ "Next, we use `uid` to intersect the two datasets. SPU provides `psi_csv`, which takes CSV files as input and writes the intersection result to CSV files. The default protocol is KKRT." ] }, { "cell_type": "code", "execution_count": 5, "id": "dec15656-51c1-4498-9f78-b2bcfd5168fb", "metadata": { "tags": [] }, "outputs": [], "source": [ "input_path = {alice: '.data/alice.csv', bob: '.data/bob.csv'}\n", "output_path = {alice: '.data/alice_psi.csv', bob: '.data/bob_psi.csv'}\n", "spu.psi_csv('uid', input_path, output_path, 'alice')" ] }, { "cell_type": "markdown", "id": "1ef80302-548f-4277-a460-fee0a07b7728", "metadata": {}, "source": [ "To check the correctness of the results, we use [pandas.DataFrame.join](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.join.html) to inner join `da` and `db`. The two datasets have been aligned on `uid` and sorted in lexicographical order of `uid`." ] }, { "cell_type": "code", "execution_count": 6, "id": "ec091c3a-83bc-41f4-85d5-351c9d98c643", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " sepal length (cm) sepal width (cm) petal length (cm) petal width (cm) \\\n", "0 6.3 3.3 6.0 2.5 \n", "1 5.8 2.7 5.1 1.9 \n", "2 7.1 3.0 5.9 2.1 \n", "3 6.3 2.9 5.6 1.8 \n", "4 6.5 3.0 5.8 2.2 \n", ".. ... ... ... ... \n", "101 5.6 2.7 4.2 1.3 \n", "102 5.7 2.9 4.2 1.3 \n", "103 6.2 2.9 4.3 1.3 \n", "104 5.1 2.5 3.0 1.1 \n", "105 5.7 2.8 4.1 1.3 \n", "\n", " uid month \n", "0 100 Feb \n", "1 101 Feb \n", "2 102 Feb \n", "3 103 Feb \n", "4 104 Feb \n", ".. ... ... \n", "101 94 Feb \n", "102 96 Feb \n", "103 97 Feb \n", "104 98 Feb \n", "105 99 Feb \n", "\n", "[106 rows x 6 columns]" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas as pd\n", "\n", "df = da.join(db.set_index('uid'), on='uid', how='inner', rsuffix='_bob', sort=True)\n", "expected = df[da.columns].astype({'uid': 'int64'}).reset_index(drop=True)\n", "\n", "da_psi = pd.read_csv('.data/alice_psi.csv')\n", "db_psi = pd.read_csv('.data/bob_psi.csv')\n", "\n", "pd.testing.assert_frame_equal(da_psi, expected)\n", "pd.testing.assert_frame_equal(db_psi, expected)\n", "\n", "da_psi" ] }, { "cell_type": "markdown", "id": "4e1015b1-4062-4f40-8c5b-3b3e346cd4d2", "metadata": {}, "source": [ "### Multi-columns PSI" ] }, { "cell_type": "markdown", "id": "26bdf77c-dd58-4053-a3f4-f69659c8c96e", "metadata": {}, "source": [ "We can also use multiple fields to intersect, the following demonstrates the use of `uid` and `month` to intersect two data. In terms of implementation, multiple fields are concatenated into a string, so please ensure that there is no duplication of the multi-column composite primary key." 
] }, { "cell_type": "code", "execution_count": 7, "id": "db24b582-ef58-4791-89d5-074be619d23f", "metadata": { "tags": [] }, "outputs": [], "source": [ "spu.psi_csv(['uid', 'month'], input_path, output_path, 'alice')" ] }, { "cell_type": "markdown", "id": "f7b34956-48c7-4e0f-bb6a-013508866267", "metadata": {}, "source": [ "Similarly, we use pandas.DataFrame.join to verify the correctness of the result. The two datasets have been aligned on `uid` and `month` and sorted in lexicographical order of these keys." ] }, { "cell_type": "code", "execution_count": 8, "id": "aebb3a76-977b-4856-b17d-066fd43230c8", "metadata": {}, "outputs": [], "source": [ "df = da.join(db.set_index(['uid', 'month']), on=['uid', 'month'], how='inner', rsuffix='_bob', sort=True)\n", "expected = df[da.columns].astype({'uid': 'int64'}).reset_index(drop=True)\n", "\n", "da_psi = pd.read_csv('.data/alice_psi.csv')\n", "db_psi = pd.read_csv('.data/bob_psi.csv')\n", "\n", "pd.testing.assert_frame_equal(da_psi, expected)\n", "pd.testing.assert_frame_equal(db_psi, expected)" ] }, { "cell_type": "markdown", "id": "f87f2dc4-c6c8-409a-b83d-fa15661def83", "metadata": {}, "source": [ "## Three parties PSI" ] }, { "cell_type": "markdown", "id": "82d884a3-4c6b-47ac-8526-ff486e89bcd4", "metadata": {}, "source": [ "Next, we add a third party, `carol`, and create a PYU device for it, as well as an SPU device built from all three parties." ] }, { "cell_type": "code", "execution_count": 9, "id": "7eb7050f-70e4-4dc6-9f07-b19b8aeb69bc", "metadata": { "tags": [] }, "outputs": [], "source": [ "carol = sf.PYU('carol')\n", "spu_3pc = sf.SPU(sf.utils.testing.cluster_def(['alice', 'bob', 'carol']))" ] }, { "cell_type": "markdown", "id": "eee3d13d-c648-4072-9ebd-c50787701262", "metadata": {}, "source": [ "Then, we use `uid` and `month` as the composite primary key to perform a three-party intersection. Note that three-party intersection currently supports only the ECDH protocol." 
] }, { "cell_type": "code", "execution_count": 10, "id": "91f58d96-ef7c-411d-9cc4-524ef1c0dd44", "metadata": { "tags": [] }, "outputs": [], "source": [ "input_path = {alice: '.data/alice.csv', bob: '.data/bob.csv', carol: '.data/carol.csv'}\n", "output_path = {alice: '.data/alice_psi.csv', bob: '.data/bob_psi.csv', carol: '.data/carol_psi.csv'}\n", "spu_3pc.psi_csv(['uid', 'month'], input_path, output_path, 'alice', protocol='ECDH_PSI_3PC')" ] }, { "cell_type": "markdown", "id": "17461ac1-1409-45ef-a513-ecfe6482feb8", "metadata": {}, "source": [ "Similarly, we use pandas.DataFrame.join to verify the correctness of the result." ] }, { "cell_type": "code", "execution_count": 11, "id": "b6ea04f0-6b8a-40d7-b07a-06e926884261", "metadata": {}, "outputs": [], "source": [ "keys = ['uid', 'month']\n", "df = da.join(db.set_index(keys), on=keys, how='inner', rsuffix='_bob', sort=True).join(\n", " dc.set_index(keys), on=keys, how='inner', rsuffix='_carol', sort=True)\n", "expected = df[da.columns].astype({'uid': 'int64'}).reset_index(drop=True)\n", "\n", "da_psi = pd.read_csv('.data/alice_psi.csv')\n", "db_psi = pd.read_csv('.data/bob_psi.csv')\n", "dc_psi = pd.read_csv('.data/carol_psi.csv')\n", "\n", "pd.testing.assert_frame_equal(da_psi, expected)\n", "pd.testing.assert_frame_equal(db_psi, expected)\n", "pd.testing.assert_frame_equal(dc_psi, expected)" ] }, { "cell_type": "markdown", "id": "6f4bb1ed-2530-46c4-b540-8c779dd93439", "metadata": {}, "source": [ "## What's Next" ] }, { "cell_type": "markdown", "id": "55afbe9a-2855-41d6-b867-ec61c4ba4d20", "metadata": {}, "source": [ "OK! Through this tutorial, we have seen how to do two-party and three-party data intersections via SPU. 
After completing the data intersection, we can perform machine learning modeling on the aligned dataset.\n", "\n", "- [Logistic Regression On SPU](./lr_with_spu.ipynb): Logistic regression modeling on SPU using JAX.\n", "- [Neural Network on SPU](./nn_with_spu.ipynb): Neural Network Modeling on SPU with JAX.\n", "- [Basic Split Learning](./split_learning_gnn.ipynb): Neural Network Modeling with TensorFlow and Split Learning." ] } ], "metadata": { "interpreter": { "hash": "885f952f3a9a8d4529917f03b3eb9071e6c6f6c8d203379e2eea819d4a2f2375" }, "kernelspec": { "display_name": "Python 3.8.12 ('secretflow')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.12" } }, "nbformat": 4, "nbformat_minor": 5 }