raiwidgets package
Package for the fairness, explanation, and error analysis widgets.
- class raiwidgets.ErrorAnalysisDashboard(explanation=None, model=None, *, dataset=None, true_y=None, classes=None, features=None, port=None, locale=None, public_ip=None, categorical_features=None, true_y_dataset=None, pred_y=None, pred_y_dataset=None, model_task=ModelTask.UNKNOWN, metric=None, max_depth=3, num_leaves=31, min_child_samples=20, sample_dataset=None)
Bases: Dashboard
ErrorAnalysis Dashboard Class.
- Parameters:
explanation (ExplanationMixin) – An object that represents an explanation.
model (object) – An object that represents a model. For classification, the model is assumed to have a predict_proba() method that returns the prediction probabilities for each class; for regression, a predict() method that returns the predicted value.
dataset (pd.DataFrame or numpy.ndarray or list[][]) – A matrix of feature vector examples (# examples x # features), the same samples used to build the explanation. Overwrites any existing dataset on the explanation object.
true_y (numpy.ndarray or list[]) – The true labels for the provided explanation. Overwrites any existing dataset on the explanation object. Note: if the explanation is computed on a sample of the dataset, you also need to specify true_y_dataset.
classes (numpy.ndarray or list[]) – The class names.
features (numpy.ndarray or list[]) – Feature names.
port (int) – The port to use for the locally hosted service.
public_ip (str) – Optional. If running on a remote VM, the external public IP address of the VM.
categorical_features (list[str]) – The categorical feature names.
true_y_dataset (numpy.ndarray or list[]) – The true labels for the provided dataset. Only needed if the explanation has a sample of instances from the original dataset. Otherwise specify true_y parameter only.
pred_y (numpy.ndarray or list[]) – The predicted y values, can be passed in as an alternative to the model and explanation for a more limited view.
pred_y_dataset (numpy.ndarray or list[] or pandas.Series) – The predicted labels for the provided dataset. Only needed if providing a sample dataset for the UI while using the full dataset for the tree view and heatmap. Otherwise specify pred_y parameter only.
model_task (str) – Optional parameter to specify whether the model is a classification or regression model. In most cases, the type of model can be inferred from the shape of the output: a classifier has a predict_proba method and outputs a two-dimensional array, while a regressor has a predict method and outputs a one-dimensional array.
metric (str) – The metric name to evaluate at each tree node or heatmap grid. Currently supported classification metrics are ‘error_rate’ and ‘accuracy_score’; ‘recall_score’, ‘precision_score’, and ‘f1_score’ for binary classification; and ‘micro_recall_score’, ‘macro_recall_score’, ‘micro_precision_score’, ‘macro_precision_score’, ‘micro_f1_score’, and ‘macro_f1_score’ for multiclass classification. Supported regression metrics are ‘mean_absolute_error’, ‘mean_squared_error’, ‘r2_score’, and ‘median_absolute_error’.
max_depth (int) – The maximum depth of the surrogate tree trained on errors.
num_leaves (int) – The number of leaves of the surrogate tree trained on errors.
min_child_samples (int) – The minimum number of samples required to create a leaf.
sample_dataset (pd.DataFrame or numpy.ndarray or list[][]) – A dataset with fewer samples than the main dataset, used to improve performance. It is used for the dataset explorer only when explanation is not specified. Specify fewer than 10,000 points for optimal performance.
locale (str) – The language in which to load and display the ErrorAnalysis Dashboard. The default language is English (“en”).
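Which metric names are valid depends on the task and, for classification, on the number of classes. The sketch below encodes the names listed under metric as a lookup table; check_metric is a hypothetical helper, not a raiwidgets function, and it assumes that ‘error_rate’ and ‘accuracy_score’ apply to both binary and multiclass classification.

```python
# Hypothetical helper (not part of raiwidgets): validates a metric name
# against the names listed in the metric parameter description above.
ANY_CLASSIFICATION = {"error_rate", "accuracy_score"}
BINARY_ONLY = {"recall_score", "precision_score", "f1_score"}
MULTICLASS_ONLY = {
    "micro_recall_score", "macro_recall_score",
    "micro_precision_score", "macro_precision_score",
    "micro_f1_score", "macro_f1_score",
}
REGRESSION = {
    "mean_absolute_error", "mean_squared_error",
    "r2_score", "median_absolute_error",
}


def check_metric(metric: str, task: str, n_classes: int = 2) -> str:
    """Return metric unchanged if it is listed for the task; else raise ValueError.

    task is assumed to be "classification" or "regression"; n_classes is
    only consulted for classification, to choose binary vs. multiclass names.
    """
    if task == "regression":
        allowed = REGRESSION
    elif task == "classification":
        # Assumption: error_rate and accuracy_score are valid for any class count.
        extra = BINARY_ONLY if n_classes == 2 else MULTICLASS_ONLY
        allowed = ANY_CLASSIFICATION | extra
    else:
        raise ValueError(f"unknown task: {task!r}")
    if metric not in allowed:
        raise ValueError(
            f"{metric!r} is not listed for task {task!r}; "
            f"choose one of {sorted(allowed)}"
        )
    return metric


check_metric("f1_score", "classification")                    # binary: accepted
check_metric("macro_f1_score", "classification", n_classes=3)  # multiclass: accepted
```

Keeping the names in sets makes the binary/multiclass split explicit, which mirrors how the parameter description groups them.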
- Example:
Run a simple view of error analysis with just predictions and true labels:
>>> predictions = model.predict(X_test)
>>> from raiwidgets import ErrorAnalysisDashboard
>>> ErrorAnalysisDashboard(dataset=X_test, true_y=y_test,
...                        features=features, pred_y=predictions)
- Example:
Run error analysis with a model and a computed explanation:
>>> from raiwidgets import ErrorAnalysisDashboard
>>> ErrorAnalysisDashboard(global_explanation, model,
...                        dataset=X_test, true_y=y_test)
- Example:
Run error analysis on a large dataset, with a downsampled dataset for the UI:
>>> from raiwidgets import ErrorAnalysisDashboard
>>> ErrorAnalysisDashboard(sample_dataset=X_test_sample,
...                        dataset=X_test,
...                        features=features,
...                        true_y=y_test_sample,
...                        true_y_dataset=y_test,
...                        pred_y=X_test_sample_pred_y,
...                        pred_y_dataset=X_test_pred_y)
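When passing sample_dataset, the sample, its true labels, and its predictions must refer to the same rows. Below is a minimal stdlib sketch of drawing such an aligned sample, assuming plain Python lists as inputs; aligned_sample and the toy data are hypothetical, and the default of 5,000 rows is simply a value below the 10,000-point guideline given for sample_dataset.

```python
import random
from typing import List, Sequence, Tuple


def aligned_sample(X: Sequence, y_true: Sequence, y_pred: Sequence,
                   n: int = 5000, seed: int = 0) -> Tuple[List, List, List]:
    """Draw the same random rows from X, y_true and y_pred.

    Hypothetical helper: keeps the three inputs row-aligned so that the
    sample_dataset, true_y and pred_y arguments describe the same instances.
    Works on plain Python sequences; a pandas DataFrame would need .iloc.
    """
    if not len(X) == len(y_true) == len(y_pred):
        raise ValueError("X, y_true and y_pred must have the same length")
    n = min(n, len(X))
    rng = random.Random(seed)                   # seeded for reproducibility
    idx = sorted(rng.sample(range(len(X)), n))  # n distinct row indices
    return ([X[i] for i in idx],
            [y_true[i] for i in idx],
            [y_pred[i] for i in idx])


# Toy data standing in for X_test, y_test and the full-dataset predictions.
X_test = [[i, i % 7] for i in range(20_000)]
y_test = [i % 2 for i in range(20_000)]
X_test_pred_y = [(i // 3) % 2 for i in range(20_000)]

X_test_sample, y_test_sample, X_test_sample_pred_y = aligned_sample(
    X_test, y_test, X_test_pred_y, n=5000)
```

The three sampled lists play the roles of sample_dataset, true_y, and pred_y in the last example, while the full lists play the roles of dataset, true_y_dataset, and pred_y_dataset.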
- class raiwidgets.ExplanationDashboard(explanation, model=None, dataset=None, true_y=None, classes=None, features=None, public_ip=None, port=None, locale=None)
Bases: Dashboard
The dashboard class, wraps the dashboard component.
- Parameters:
explanation (ExplanationMixin) – An object that represents an explanation.
model (object) – An object that represents a model. For classification, the model is assumed to have a predict_proba() method that returns the prediction probabilities for each class; for regression, a predict() method that returns the predicted value.
dataset (numpy.ndarray or list[][]) – A matrix of feature vector examples (# examples x # features), the same samples used to build the explanation. Overwrites any existing dataset on the explanation object. Must have fewer than 100000 rows and fewer than 1000 columns. Note: the dashboard may become slow or crash with more than 10000 rows.
true_y (numpy.ndarray or list[]) – The true labels for the provided dataset. Overwrites any existing dataset on the explanation object.
classes (numpy.ndarray or list[]) – The class names.
features (numpy.ndarray or list[]) – Feature names.
public_ip (str) – Optional. If running on a remote VM, the external public IP address of the VM.
port (int) – The port to use for the locally hosted service.
locale (str) – The language in which to load and display the Explanation Dashboard. The default language is English (“en”).
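The row and column limits on dataset can be checked before constructing the dashboard. A hedged sketch, assuming the dataset is a list of equal-length rows; check_explanation_dataset is a hypothetical helper, and raiwidgets may enforce these limits differently or at a different point.

```python
import warnings
from typing import Sequence

MAX_ROWS = 100_000   # dataset must have fewer rows than this
MAX_COLS = 1_000     # ... and fewer columns than this
SLOW_ROWS = 10_000   # above this, the dashboard may become slow or crash


def check_explanation_dataset(dataset: Sequence[Sequence]) -> None:
    """Hypothetical pre-flight check mirroring the documented dataset limits."""
    n_rows = len(dataset)
    n_cols = len(dataset[0]) if n_rows else 0
    if n_rows >= MAX_ROWS or n_cols >= MAX_COLS:
        raise ValueError(
            f"dataset is {n_rows} x {n_cols}; it must have fewer than "
            f"{MAX_ROWS} rows and fewer than {MAX_COLS} columns")
    if n_rows > SLOW_ROWS:
        # Not an error, but worth surfacing before launching the UI.
        warnings.warn(f"{n_rows} rows: the dashboard may become slow or crash")


check_explanation_dataset([[0.0] * 20] * 500)  # within limits, no warning
```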
- class raiwidgets.FairnessDashboard(*, sensitive_features, y_true, y_pred, locale=None, public_ip=None, port=None, fairness_metric_module=None, fairness_metric_mapping=None)
Bases: Dashboard
The dashboard class, wraps the dashboard component.
- Parameters:
sensitive_features (pandas.Series, pandas.DataFrame, list, Dict[str,1d array] or something convertible to numpy.ndarray) – The sensitive features. These can be from the initial dataset or reserved from training. If the input type provides names, they are used; otherwise, names of the form “Sensitive Feature <n>” are generated.
y_true (numpy.ndarray or list[]) – The true labels or values for the provided dataset.
y_pred (pandas.Series, pandas.DataFrame, list, Dict[str,1d array] or something convertible to numpy.ndarray) – Array of output predictions from the models to be evaluated. If the input type provides names, they are used; otherwise, names of the form “Model <n>” are generated.
locale (str) – The language in which to load and display the Fairness Dashboard. The default language is English (“en”).
public_ip (str) – Optional. If running on a remote VM, the external public IP address of the VM.
port (int) – The port to use for the locally hosted service.
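To illustrate what evaluating predictions against sensitive features involves, the sketch below computes accuracy separately for each sensitive-feature group and for each named model, then reports the gap between the best and worst group. It is a minimal stdlib sketch under stated assumptions (a single sensitive feature, accuracy as the metric, y_pred given as a dict of model names); accuracy_by_group is hypothetical and is not taken from the FairnessDashboard implementation.

```python
from collections import defaultdict
from typing import Dict, Hashable, Sequence


def accuracy_by_group(y_true: Sequence, y_pred: Dict[str, Sequence],
                      sensitive: Sequence[Hashable]) -> Dict[str, Dict[Hashable, float]]:
    """Accuracy of each named model within each sensitive-feature group."""
    result = {}
    for model_name, preds in y_pred.items():
        hits = defaultdict(int)    # correct predictions per group
        counts = defaultdict(int)  # examples per group
        for t, p, g in zip(y_true, preds, sensitive):
            counts[g] += 1
            hits[g] += int(t == p)
        result[model_name] = {g: hits[g] / counts[g] for g in counts}
    return result


y_true = [1, 0, 1, 1, 0, 1]
y_pred = {"model_a": [1, 0, 0, 1, 0, 1], "model_b": [1, 1, 1, 1, 0, 0]}
groups = ["f", "f", "f", "m", "m", "m"]

scores = accuracy_by_group(y_true, y_pred, groups)
for name, per_group in scores.items():
    # Disparity: difference between the best- and worst-served group.
    gap = max(per_group.values()) - min(per_group.values())
    print(name, per_group, f"gap={gap:.2f}")
```

Here model_a is more accurate overall than model_b but serves the two groups unequally, which is the kind of trade-off a disaggregated view is meant to expose.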
- class raiwidgets.ModelAnalysisDashboard(analysis: RAIInsights, public_ip=None, port=None, locale=None)
Bases: object
The dashboard class, wraps the dashboard component.
Note: this class is deprecated; use ResponsibleAIDashboard instead.
- Parameters:
analysis (RAIInsights) – An object that represents a model analysis.
public_ip (str) – Optional. If running on a remote VM, the external public IP address of the VM.
port (int) – The port to use for the locally hosted service.
locale (str) – The language in which to load and display the ModelAnalysis Dashboard. The default language is English (“en”).
- class raiwidgets.ModelPerformanceDashboard(model=None, dataset=None, true_y=None, classes=None, features=None, public_ip=None, port=None, locale=None)
Bases: Dashboard
The dashboard class, wraps the dashboard component.
- Parameters:
model (object) – An object that represents a model. For classification, the model is assumed to have a predict_proba() method that returns the prediction probabilities for each class; for regression, a predict() method that returns the predicted value.
dataset (numpy.ndarray or list[][]) – A matrix of feature vector examples (# examples x # features), the same samples used to build the explanation. Overwrites any existing dataset on the explanation object. Must have fewer than 10000 rows and fewer than 1000 columns.
true_y (numpy.ndarray or list[]) – The true labels for the provided dataset. Overwrites any existing dataset on the explanation object.
classes (numpy.ndarray or list[]) – The class names.
features (numpy.ndarray or list[]) – Feature names.
public_ip (str) – Optional. If running on a remote VM, the external public IP address of the VM.
port (int) – The port to use for the locally hosted service.
locale (str) – The language in which to load and display the ModelPerformance Dashboard. The default language is English (“en”).
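Like the other dashboards, ModelPerformanceDashboard relies on the predict_proba()/predict() convention described under model. Below is a sketch of telling the two cases apart from the output shape, following the description of model_task for ErrorAnalysisDashboard; infer_task and the stub models are hypothetical, and the actual inference logic in raiwidgets may differ.

```python
from typing import Any, Sequence


def _is_2d(values: Any) -> bool:
    """True if values is a non-empty sequence whose first item is itself sized."""
    try:
        return len(values) > 0 and len(values[0]) >= 1
    except TypeError:  # first item is a scalar, so the output is 1-D
        return False


def infer_task(model: Any, X: Sequence) -> str:
    """Hypothetical: guess 'classification' or 'regression' from model outputs."""
    if hasattr(model, "predict_proba") and _is_2d(model.predict_proba(X)):
        return "classification"  # one row of class probabilities per example
    if hasattr(model, "predict") and not _is_2d(model.predict(X)):
        return "regression"      # one predicted value per example
    raise TypeError("model needs a 2-D predict_proba() or a 1-D predict()")


class StubClassifier:
    """Hypothetical binary classifier returning fixed probabilities."""
    def predict(self, X):
        return [0 for _ in X]

    def predict_proba(self, X):
        return [[0.7, 0.3] for _ in X]


class StubRegressor:
    """Hypothetical regressor that predicts the sum of the features."""
    def predict(self, X):
        return [float(sum(row)) for row in X]


X = [[1.0, 2.0], [3.0, 4.0]]
print(infer_task(StubClassifier(), X), infer_task(StubRegressor(), X))
```

Checking predict_proba first matters because many classifiers also expose predict(); when the shape is ambiguous, passing model_task explicitly (where a dashboard accepts it) avoids relying on inference.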
- class raiwidgets.ResponsibleAIDashboard(analysis: RAIInsights, public_ip=None, port=None, locale=None, cohort_list=None, is_private_link=False, **kwargs)
Bases: Dashboard
The dashboard class, wraps the dashboard component.
- Parameters:
analysis (RAIInsights) – An object that represents an RAIInsights analysis.
public_ip (str) – Optional. If running on a remote VM, the external public IP address of the VM.
port (int) – The port to use for the locally hosted service.
locale (str) – The language in which to load and display the ResponsibleAI Dashboard. The default language is English (“en”).
cohort_list (List[Cohort]) – List of cohorts defined by the user for the dashboard.
is_private_link (bool) – Whether the dashboard environment is a private-link Azure Machine Learning (AML) workspace.