secretflow.ml.nn.fl.backend.tensorflow

secretflow.ml.nn.fl.backend.tensorflow.fl_base

Classes:

BaseTFModel(builder_base[, random_seed])

class secretflow.ml.nn.fl.backend.tensorflow.fl_base.BaseTFModel(builder_base: Callable[[], Model], random_seed: Optional[int] = None)

Bases: object

Methods:

__init__(builder_base[, random_seed])

build_dataset_from_csv(csv_file_path, label)

Build a tf.data.Dataset from a CSV file.

build_dataset(x[, y, s_w, sampling_rate, ...])

Build a tf.data.Dataset from in-memory arrays.

build_dataset_from_builder(dataset_builder, x)

Build a tf.data.Dataset with a user-supplied builder function.

get_rows_count(filename)

get_weights()

set_weights(weights)

Set the weights of the client model.

set_validation_metrics(global_metrics)

wrap_local_metrics()

evaluate([evaluate_steps])

predict([predict_steps])

init_training(callbacks[, epochs, steps, ...])

on_train_begin()

on_epoch_begin(epoch)

on_epoch_end(epoch)

on_train_end()

get_stop_training()

train_step(weights, cur_steps, train_steps, ...)

save_model(model_path)

load_model(model_path)

__init__(builder_base: Callable[[], Model], random_seed: Optional[int] = None)
build_dataset_from_csv(csv_file_path: str, label: str, sampling_rate=None, shuffle=False, random_seed=1234, na_value='?', repeat_count=1, sample_length=0, buffer_size=None, ignore_errors=True, prefetch_buffer_size=None, stage='train', label_decoder=None)

Build a tf.data.Dataset from a CSV file.

Parameters:
  • csv_file_path – Path of the CSV file.

  • label – Name of the label column.

  • sampling_rate – Sampling rate used to form each batch.

  • shuffle – Whether to shuffle the input.

  • random_seed – Random seed used for shuffling.

  • na_value – Additional string to recognize as NA/NaN.

  • repeat_count – Number of times to repeat the dataset.

  • sample_length – Number of samples (sample length).

  • buffer_size – Shuffle buffer size.

  • ignore_errors – If True, ignore errors encountered while parsing the CSV file.

  • prefetch_buffer_size – Number of feature batches to prefetch to improve performance.

  • stage – Stage of the dataset; defaults to 'train'.

  • label_decoder – Callable used to preprocess labels.
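build_dataset_from_csv separates the label column from the features and treats na_value as a missing-value marker. Its parameters also closely parallel those of tf.data.experimental.make_csv_dataset (label_name, na_value, shuffle_seed, ignore_errors, prefetch_buffer_size), which suggests, without this page confirming it, that the method wraps that function. The stdlib-only sketch below illustrates just the label split and NA handling; split_features_and_label is a hypothetical helper that returns plain Python lists, not a tf.data.Dataset.

```python
import csv
import io


def split_features_and_label(csv_text, label, na_value="?"):
    """Hypothetical helper: split CSV rows into feature dicts and a label list.

    Cells equal to ``na_value`` are replaced by None to mark them as missing.
    """
    features, labels = [], []
    for row in csv.DictReader(io.StringIO(csv_text)):
        # Map the NA marker to None before separating the label.
        row = {k: (None if v == na_value else v) for k, v in row.items()}
        labels.append(row.pop(label))  # remove the label column from the features
        features.append(row)
    return features, labels


csv_text = "age,income,y\n30,?,1\n45,5000,0\n"
features, labels = split_features_and_label(csv_text, label="y")
print(features)  # [{'age': '30', 'income': None}, {'age': '45', 'income': '5000'}]
print(labels)    # ['1', '0']
```

All values stay strings here; type inference and batching are left to the real TensorFlow pipeline.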

build_dataset(x: ndarray, y: Optional[ndarray] = None, s_w: Optional[ndarray] = None, sampling_rate=None, buffer_size=None, shuffle=False, random_seed=1234, repeat_count=1, sampler_method='batch', stage='train')

Build a tf.data.Dataset from in-memory arrays.

Parameters:
  • x – Features, FedNdArray or HDataFrame.

  • y – Labels, FedNdArray or HDataFrame.

  • s_w – Sample weights of this dataset.

  • sampling_rate – Sampling rate used to form each batch.

  • buffer_size – Shuffle buffer size.

  • shuffle – Whether to shuffle the input.

  • random_seed – Seed of the pseudo-random generator used for shuffling.

  • repeat_count – Number of times to repeat the dataset.

  • sampler_method – Sampling method to use; defaults to 'batch'.

build_dataset_from_builder(dataset_builder: Callable, x: Union[DataFrame, str], y: Optional[ndarray] = None, s_w: Optional[ndarray] = None, repeat_count=1, stage='train')

Build a tf.data.Dataset with a user-supplied builder function.

Parameters:
  • dataset_builder – Function that builds the dataset; it must return the dataset and the number of steps per epoch.

  • x – Input data: a path to a CSV file or data folder, or a DataFrame.

  • y – Labels, FedNdArray or HDataFrame.

  • s_w – Sample weights. Defaults to None, in which case all samples are assumed to have equal weight.

  • repeat_count – Number of times to repeat the dataset, which increases its effective size.

  • stage – Stage of the dataset to build; useful for separating training, validation, and test datasets.

Returns:

A TensorFlow dataset
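The only contract stated for dataset_builder is that it returns the dataset together with the number of steps per epoch. A minimal sketch of a callable honoring that contract follows. toy_dataset_builder is hypothetical: it takes an in-memory list instead of the CSV path or DataFrame the method receives, and it returns a list of batches where a real builder would return a tf.data.Dataset.

```python
import math


def toy_dataset_builder(rows, batch_size=2):
    """Hypothetical dataset_builder: returns (dataset, steps_per_epoch)."""
    # Stand-in "dataset": consecutive slices of the input, last one may be partial.
    dataset = [rows[i:i + batch_size] for i in range(0, len(rows), batch_size)]
    # One step per batch, so round up to count the partial batch.
    steps_per_epoch = math.ceil(len(rows) / batch_size)
    return dataset, steps_per_epoch


dataset, steps = toy_dataset_builder(list(range(5)))
print(dataset)  # [[0, 1], [2, 3], [4]]
print(steps)    # 3
```

Returning the step count alongside the dataset lets the caller bound each epoch without having to iterate the dataset to find its length.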

get_rows_count(filename)
get_weights()
set_weights(weights)

Set the weights of the client model.

set_validation_metrics(global_metrics)
wrap_local_metrics()
evaluate(evaluate_steps=0)
predict(predict_steps=0)
init_training(callbacks, epochs=1, steps=0, verbose=0)
on_train_begin()
on_epoch_begin(epoch)
on_epoch_end(epoch)
on_train_end()
get_stop_training()
abstract train_step(weights, cur_steps, train_steps, **kwargs)
save_model(model_path: str)
load_model(model_path: str)
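BaseTFModel exposes Keras-style lifecycle hooks plus an abstract train_step, so a concrete backend model has to implement train_step itself. This page does not document the order in which a driver calls these methods; the sketch below assumes the conventional Keras ordering (train begin; per epoch: epoch begin, steps, epoch end, stop check; then train end) and assumes train_step returns updated weights. LifecycleSketch, CountingModel, and fit are hypothetical plain-Python stand-ins, weights is a plain integer rather than a list of tensors, and init_training is not modeled.

```python
from abc import ABC, abstractmethod


class LifecycleSketch(ABC):
    """Hypothetical stand-in exposing BaseTFModel-style hook names."""

    def __init__(self):
        self.events = []
        self.stop_training = False

    def on_train_begin(self):
        self.events.append("train_begin")

    def on_epoch_begin(self, epoch):
        self.events.append(f"epoch_begin:{epoch}")

    def on_epoch_end(self, epoch):
        self.events.append(f"epoch_end:{epoch}")

    def on_train_end(self):
        self.events.append("train_end")

    def get_stop_training(self):
        return self.stop_training

    @abstractmethod
    def train_step(self, weights, cur_steps, train_steps, **kwargs):
        """Run one step and return updated weights (subclass-defined)."""


class CountingModel(LifecycleSketch):
    def train_step(self, weights, cur_steps, train_steps, **kwargs):
        self.events.append(f"step:{cur_steps}")
        return weights + 1  # toy update: weights is a plain int here


def fit(model, epochs, steps_per_epoch):
    """Assumed driver loop: hooks wrap the per-step train_step calls."""
    weights = 0
    model.on_train_begin()
    for epoch in range(epochs):
        model.on_epoch_begin(epoch)
        for step in range(steps_per_epoch):
            weights = model.train_step(weights, step, steps_per_epoch)
        model.on_epoch_end(epoch)
        if model.get_stop_training():  # e.g. set by an early-stopping callback
            break
    model.on_train_end()
    return weights


m = CountingModel()
print(fit(m, epochs=2, steps_per_epoch=2))  # 4
print(m.events)
```

In a federated round, the weights handed to train_step would typically be the aggregated global weights, but that flow is outside what this page documents.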

secretflow.ml.nn.fl.backend.tensorflow.sampler

Functions:

batch_sampler(x, y, s_w, sampling_rate, ...)

Implementation of the batch sampler.

possion_sampler(x, y, s_w, sampling_rate, ...)

Implementation of the Poisson sampler.

sampler_data([sampler_method, x, y, s_w, ...])

Sample data with the sampler selected by sampler_method.

secretflow.ml.nn.fl.backend.tensorflow.sampler.batch_sampler(x, y, s_w, sampling_rate, buffer_size, shuffle, repeat_count, random_seed)

Implementation of the batch sampler.

Parameters:
  • x – Features, FedNdArray or HDataFrame.

  • y – Labels, FedNdArray or HDataFrame.

  • s_w – Sample weights of this dataset.

  • sampling_rate – Sampling rate used to form each batch.

  • buffer_size – Shuffle buffer size.

  • shuffle – Whether to shuffle the input.

  • repeat_count – Number of times to repeat the dataset.

  • random_seed – Seed of the pseudo-random generator used for shuffling.

Returns:

data_set, the sampled dataset

Return type:

tf.data.Dataset
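batch_sampler takes a sampling_rate but no batch size. One plausible reading, assumed here and not confirmed by this page, is a fixed batch size of floor(n * sampling_rate) with a partial final batch. The plain-Python sketch below follows that assumption; batch_sampler_sketch is hypothetical, omits s_w, and performs a full in-memory shuffle where tf.data's buffer_size would bound how far elements can move.

```python
import math
import random


def batch_sampler_sketch(x, y, sampling_rate, shuffle, repeat_count, random_seed):
    """Hypothetical: fixed-size batches (assumed floor rule), seeded shuffle, repeats."""
    n = len(x)
    batch_size = math.floor(n * sampling_rate)
    if batch_size <= 0:
        raise ValueError("sampling_rate too small: batch size would be 0")
    rng = random.Random(random_seed)
    batches = []
    for _ in range(repeat_count):
        order = list(range(n))
        if shuffle:
            # Full shuffle, drawing a fresh permutation on every pass.
            rng.shuffle(order)
        for start in range(0, n, batch_size):
            idx = order[start:start + batch_size]
            batches.append(([x[i] for i in idx], [y[i] for i in idx]))
    return batches


x = [[0.0], [1.0], [2.0], [3.0], [4.0]]
y = [0, 1, 0, 1, 0]
batches = batch_sampler_sketch(x, y, sampling_rate=0.4, shuffle=False,
                               repeat_count=2, random_seed=1234)
print([len(b[0]) for b in batches])  # [2, 2, 1, 2, 2, 1]
```

With 5 samples and a rate of 0.4, each pass yields two full batches of 2 and one partial batch of 1.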

secretflow.ml.nn.fl.backend.tensorflow.sampler.possion_sampler(x, y, s_w, sampling_rate, random_seed)

Implementation of the Poisson sampler.

Parameters:
  • x – Features, FedNdArray or HDataFrame.

  • y – Labels, FedNdArray or HDataFrame.

  • s_w – Sample weights of this dataset.

  • sampling_rate – Sampling rate used to form each batch.

  • random_seed – Seed of the pseudo-random generator used for sampling.

Returns:

data_set, the sampled dataset

Return type:

tf.data.Dataset
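The page does not say how possion_sampler draws batches. Under standard Poisson sampling, each example is included in a batch independently with probability sampling_rate, so batch sizes vary around n * sampling_rate instead of being fixed. The sketch below illustrates that scheme in plain Python, assuming it is what the function implements; poisson_sampler_sketch and its num_batches argument are hypothetical, and s_w is omitted.

```python
import random


def poisson_sampler_sketch(x, y, sampling_rate, num_batches, random_seed):
    """Hypothetical: each example joins a batch independently with prob. sampling_rate."""
    rng = random.Random(random_seed)
    batches = []
    for _ in range(num_batches):
        # Independent Bernoulli(sampling_rate) draw per example.
        idx = [i for i in range(len(x)) if rng.random() < sampling_rate]
        batches.append(([x[i] for i in idx], [y[i] for i in idx]))
    return batches


x = list(range(1000))
y = [i % 2 for i in x]
batches = poisson_sampler_sketch(x, y, sampling_rate=0.05, num_batches=3, random_seed=1234)
print([len(b[0]) for b in batches])  # sizes vary around 1000 * 0.05 = 50
```

Because the seed fixes the generator, the same arguments always reproduce the same batches.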

secretflow.ml.nn.fl.backend.tensorflow.sampler.sampler_data(sampler_method='batch', x=None, y=None, s_w=None, sampling_rate=None, buffer_size=None, shuffle=False, repeat_count=1, random_seed=1234)

Sample data with the sampler selected by sampler_method.

Parameters:
  • sampler_method – Sampling method to use; defaults to 'batch'.

  • x – Features, FedNdArray or HDataFrame.

  • y – Labels, FedNdArray or HDataFrame.

  • s_w – Sample weights of this dataset.

  • sampling_rate – Sampling rate used to form each batch.

  • buffer_size – Shuffle buffer size.

  • shuffle – Whether to shuffle the input.

  • repeat_count – Number of times to repeat the dataset.

  • random_seed – Seed of the pseudo-random generator used for shuffling.

Returns:

data_set, the sampled dataset

Return type:

tf.data.Dataset