Getting started

Tutorials

Walkthroughs for each task type, end-to-end.

There is one tutorial per task type. Each one is self-contained — you can paste the code into a fresh notebook and follow along. They use sample datasets shipped with pycaret.datasets, so there's nothing to download.

Classification — predicting customer purchase

from pycaret.classification import ClassificationExperiment
from pycaret.datasets import get_data

# Load the "juice" sample dataset that ships with pycaret.datasets.
juice_df = get_data("juice", verbose=False)

# Configure the experiment and fit it on the data in one chained call.
clf_exp = ClassificationExperiment(
    target="Purchase", session_id=42, normalize=True
).fit(juice_df)

# Compare 12 models; the call returns a CompareResult.
comparison = clf_exp.compare_models(n_select=3)
leaderboard_head = comparison.leaderboard.head()
print(leaderboard_head)

# Tune the best model from the comparison, optimizing for AUC.
tuning = clf_exp.tune_model(comparison.best, n_iter=20, optimize="AUC")
print("best params:", tuning.best_params)

# Run predictions with the tuned pipeline and print the resulting metrics.
prediction = clf_exp.predict_model(tuning.pipeline)
print(prediction.metrics)

# Finalize the tuned pipeline and persist it for production.
finalized = clf_exp.finalize_model(tuning.pipeline)
clf_exp.save_model(finalized.pipeline, "juice-classifier")

Regression — Boston housing prices

from pycaret.regression import RegressionExperiment
from pycaret.datasets import get_data
from pycaret.plots.regression import residuals, prediction_error

# Load the "boston" sample dataset that ships with pycaret.datasets.
data = get_data("boston", verbose=False)

# Configure the experiment and fit it on the data in one chained call.
exp = RegressionExperiment(
    target="medv",
    session_id=42,
    normalize=True,
    transformation=True,
).fit(data)

# Rank the candidate models by RMSE and keep the best one, then tune it
# against the same metric.
best = exp.compare_models(sort="RMSE").best
tuned = exp.tune_model(best, n_iter=20, optimize="RMSE")

# Residual diagnostics — Plotly figure ready for fig.show() / API serving.
residuals(tuned.pipeline, exp.X_test, exp.y_test).show()
prediction_error(tuned.pipeline, exp.X_test, exp.y_test).show()

Clustering — segmenting jewellery customers

from pycaret.clustering import ClusteringExperiment
from pycaret.datasets import get_data
from pycaret.plots.clustering import elbow_curve, silhouette_curve, embedding_2d

# Load the "jewellery" sample dataset that ships with pycaret.datasets.
data = get_data("jewellery", verbose=False)

exp = ClusteringExperiment(session_id=42, normalize=True).fit(data)

# The plot helpers take the transformed feature matrix; look it up once.
# NOTE(review): _fit_state is a private attribute — switch to a public
# accessor if the experiment exposes one.
X_transformed = exp._fit_state["X_transformed"]

# Pick k via the elbow + silhouette charts.
# Fit the default k-means model once and reuse it for both charts instead of
# fitting an identical model for each plot.
baseline = exp.create_model("kmeans")
elbow_curve(baseline.pipeline, X_transformed).show()
silhouette_curve(baseline.pipeline, X_transformed).show()

# Final model with the chosen k.
res = exp.create_model("kmeans", num_clusters=4)
labelled = exp.assign_model(res.pipeline)  # original df + Cluster column

embedding_2d(res.pipeline, X_transformed).show()

Anomaly detection — identifying outliers

from pycaret.anomaly import AnomalyExperiment
from pycaret.datasets import get_data
from pycaret.plots.anomaly import score_distribution, anomaly_map

# Load the "anomaly" sample dataset that ships with pycaret.datasets.
anomaly_df = get_data("anomaly", verbose=False)

# Fit the experiment, then build an IsolationForest ("iforest") detector.
detector_exp = AnomalyExperiment(session_id=42).fit(anomaly_df)
iforest = detector_exp.create_model("iforest")

# assign_model attaches the Anomaly + Anomaly_Score columns.
scored = detector_exp.assign_model(iforest.pipeline)
print(scored.head())

# Both charts are drawn from the model pipeline and the transformed features.
features = detector_exp._fit_state["X_transformed"]
score_distribution(iforest.pipeline, features).show()
anomaly_map(iforest.pipeline, features).show()

Time-series — airline passengers

from pycaret.time_series import TimeSeriesExperiment
from pycaret.datasets import get_data
from pycaret.plots.time_series import (
    forecast,
    decomposition,
    residual_diagnostics,
)

# Load the "airline" sample dataset that ships with pycaret.datasets.
airline = get_data("airline", verbose=False)

ts_exp = TimeSeriesExperiment(fh=12, session_id=42).fit(airline)

# Compare a fixed shortlist of forecasters, ranked by MASE.
candidates = ["arima", "ets", "theta", "naive"]
ranking = ts_exp.compare_models(include=candidates, sort="MASE")
print(ranking.leaderboard)

# Tune the best candidate, then predict with intervals requested.
tuned = ts_exp.tune_model(ranking.best, n_iter=10, optimize="MASE")
best_pipeline = tuned.pipeline
prediction = ts_exp.predict_model(best_pipeline, return_pred_int=True)
pred_frame = prediction.predictions
y_hat = pred_frame["y_pred"]

# Forecast + diagnostics.
forecast_fig = forecast(
    y_true=ts_exp.y_test,
    y_pred=y_hat,
    lower=pred_frame["lower"],
    upper=pred_frame["upper"],
    history=ts_exp.y_train,
)
forecast_fig.show()

decomposition(ts_exp.y_train, period=12).show()
residual_diagnostics(ts_exp.y_test, y_hat).show()

What to do next

  • Read Modules for an overview of all five task modules and their differences.
  • Read Functions / Initialize to understand the Experiment(...).fit() setup call in depth.
  • Browse API reference for every public symbol.