
MLflow

This package wraps some of MLflow's life-cycle management features.

MLflow Skinny: A Lightweight Machine Learning Lifecycle Platform Client.

MLflow Skinny is a lightweight MLflow package without SQL storage, server, UI, or data science dependencies. It supports tracking operations (logging, loading, and searching params, metrics, and tags, plus logging and loading artifacts); model registration, search, artifact loading, and deployment; and execution of GitHub projects within a notebook and against a remote target.

GetLatestModelVersion

GetLatestModelVersion(tracking_uri: str, model_name: str, **kwargs: Any)

Bases: BaseOperator

Custom wrapper around MlflowClient.search_model_versions that returns the latest version of a registered model.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| tracking_uri | str | MLflow tracking server URI | required |
| model_name | str | Name of the MLflow model | required |
| kwargs | Any | Keyword arguments passed to BaseOperator | {} |

Returns:

| Type | Description |
|---|---|
| dict | Latest version number and its run ID: {"version": latest_version.version, "run_id": latest_version.run_id} |
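A minimal sketch of the selection step, assuming the operator searches with a name filter and then keeps the highest version. FakeModelVersion, build_name_filter, and pick_latest are hypothetical names; the real operator works on mlflow.entities.model_registry.ModelVersion objects and may order or filter results differently. Versions are compared as integers because they may arrive as strings (the version parameter elsewhere in this module is typed str), and as strings "10" would sort before "9".

```python
from dataclasses import dataclass


@dataclass
class FakeModelVersion:
    """Hypothetical stand-in for mlflow.entities.model_registry.ModelVersion."""

    version: str  # e.g. "3"; assumed to be string-typed
    run_id: str


def build_name_filter(model_name: str) -> str:
    # Filter string in the syntax accepted by MlflowClient.search_model_versions.
    return f"name='{model_name}'"


def pick_latest(versions: list[FakeModelVersion]) -> dict[str, str]:
    """Return the dict shape documented for GetLatestModelVersion."""
    if not versions:
        raise ValueError("no model versions found")
    # Compare numerically: as plain strings, "10" would sort before "9".
    latest = max(versions, key=lambda mv: int(mv.version))
    return {"version": latest.version, "run_id": latest.run_id}


versions = [FakeModelVersion("9", "run-a"), FakeModelVersion("10", "run-b")]
print(build_name_filter("batch"))  # name='batch'
print(pick_latest(versions))  # {'version': '10', 'run_id': 'run-b'}
```

With a real client, the list would come from something like MlflowClient(tracking_uri=...).search_model_versions(build_name_filter(model_name)).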

GetModelMetaData

GetModelMetaData(tracking_uri: str, upstream_task_id: str, **kwargs: Any)

Bases: BaseOperator

Wrapper around MlflowClient.get_run.

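Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| tracking_uri | str | MLflow client tracking URI | required |
| upstream_task_id | str | ID of the upstream task whose XCom value contains a run_id key | required |
| kwargs | Any | Keyword arguments passed to BaseOperator | {} |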
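The upstream_task_id contract can be sketched as follows, assuming the upstream task returned a dict such as the one documented for GetLatestModelVersion. run_metadata_from_xcom is a hypothetical helper, and the get_run callable stands in for MlflowClient(tracking_uri=...).get_run so the example runs without a tracking server; the operator's real implementation may differ.

```python
from typing import Any, Callable


def run_metadata_from_xcom(
    xcom_value: dict[str, Any], get_run: Callable[[str], Any]
) -> Any:
    """Hypothetical helper: read run_id from an upstream XCom dict, then look up the run."""
    if "run_id" not in xcom_value:
        # Fail with a clear message instead of a bare KeyError deep in a task log.
        raise KeyError(f"upstream XCom value has no 'run_id' key: {xcom_value!r}")
    return get_run(xcom_value["run_id"])


# Example with a fake lookup in place of MlflowClient.get_run.
upstream = {"version": "10", "run_id": "run-b"}
fake_runs = {"run-b": {"params": {"lr": "0.01"}}}
metadata = run_metadata_from_xcom(upstream, fake_runs.__getitem__)
print(metadata)  # {'params': {'lr': '0.01'}}
```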

GetModelVersionByNameAndStage

GetModelVersionByNameAndStage(tracking_uri: str, model_name: str, stage: MlflowModelStage, **kwargs: Any)

Bases: BaseOperator

Wrapper around MLflow's search_model_versions, filtered by model name and tag.stage.

Example:

```python
# Get the production model's run ID
from dagcellent.operators.mlflow import (
    GetModelVersionByNameAndStage,
    MlflowModelStage,
)

get_prod_run_id = GetModelVersionByNameAndStage(
    task_id="get_prod_run_id",
    tracking_uri="<login>:<password>@<domain>/<login>/mlflow",
    model_name="<skynet-auto>",
    stage=MlflowModelStage.PRODUCTION,
)
```

If no stage is provided, the operator returns the latest version for each stage.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| tracking_uri | str | MLflow client tracking URI | required |
| model_name | str | MLflow model name | required |
| stage | MlflowModelStage | Stage to filter model versions by | required |
| kwargs | Any | Keyword arguments passed to BaseOperator | {} |
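The stage filtering, including the documented fallback of returning the latest version for each stage when no stage is given, can be sketched client-side. This is an assumption about the behavior, not the operator's code: FakeModelVersion and latest_by_stage are hypothetical, stages are plain strings instead of MlflowModelStage members, and the stage is assumed to live in a model version tag named "stage".

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class FakeModelVersion:
    """Hypothetical stand-in for an MLflow ModelVersion with tags."""

    version: str
    run_id: str
    tags: dict[str, str] = field(default_factory=dict)


def latest_by_stage(
    versions: list[FakeModelVersion], stage: Optional[str] = None
) -> dict[str, FakeModelVersion]:
    """Map each stage tag value to its highest version, optionally keeping one stage."""
    latest: dict[str, FakeModelVersion] = {}
    for mv in versions:
        tag = mv.tags.get("stage")
        if tag is None or (stage is not None and tag != stage):
            continue  # untagged, or a stage we were not asked for
        current = latest.get(tag)
        if current is None or int(mv.version) > int(current.version):
            latest[tag] = mv
    return latest


versions = [
    FakeModelVersion("1", "r1", {"stage": "production"}),
    FakeModelVersion("2", "r2", {"stage": "staging"}),
    FakeModelVersion("3", "r3", {"stage": "production"}),
    FakeModelVersion("4", "r4"),  # no stage tag: ignored
]
print(latest_by_stage(versions, "production")["production"].run_id)  # r3
print(sorted(latest_by_stage(versions)))  # ['production', 'staging']
```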

MlflowModelStage

Bases: Enum

Allowed MLflow 'stage' tag values.
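Restricting the stage tag to an Enum means an invalid value fails at lookup time instead of silently matching nothing. A minimal sketch, assuming lowercase string values: only the PRODUCTION member is confirmed by the GetModelVersionByNameAndStage example, so the other members, all values, StageSketch, and the parse_stage helper are hypothetical.

```python
from enum import Enum


class StageSketch(Enum):
    """Hypothetical mirror of MlflowModelStage; members other than PRODUCTION are assumed."""

    STAGING = "staging"
    PRODUCTION = "production"
    ARCHIVED = "archived"


def parse_stage(raw: str) -> StageSketch:
    # Lookup by value raises ValueError for anything outside the allowed set.
    return StageSketch(raw.strip().lower())


print(parse_stage(" Production "))  # StageSketch.PRODUCTION
try:
    parse_stage("qa")
except ValueError as err:
    print(f"rejected: {err}")
```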

SetModelVersionTag

SetModelVersionTag(tracking_uri: str, model_name: str, version: str, tag: dict[str, str], **kwargs: Any)

Bases: BaseOperator

Wrapper around MlflowClient.set_model_version_tag.

Tags must be passed as key-value pairs; multiple tags can be set at once.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| tracking_uri | str | MLflow client tracking URI | required |
| model_name | str | MLflow model name, e.g. batch | required |
| version | str | Model version | required |
| tag | dict[str, str] | Arbitrary tags in {key: value} format | required |
| kwargs | Any | Keyword arguments passed to BaseOperator | {} |
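MlflowClient.set_model_version_tag sets a single key and value per call, so supporting several tags means one call per dict entry. A minimal sketch of that loop, assuming this is how the operator applies its tag argument; set_tags is a hypothetical helper, and the setter is injected so the example runs without an MLflow server (in practice it would be MlflowClient(tracking_uri=...).set_model_version_tag).

```python
from typing import Callable

# Assumed to match the positional order of MlflowClient.set_model_version_tag(name, version, key, value).
TagSetter = Callable[[str, str, str, str], None]


def set_tags(model_name: str, version: str, tag: dict[str, str], setter: TagSetter) -> None:
    """Hypothetical helper: apply every {key: value} pair to one model version."""
    for key, value in tag.items():
        setter(model_name, version, key, value)


# Record the calls instead of talking to a tracking server.
calls: list[tuple[str, str, str, str]] = []
set_tags(
    "batch",
    "3",
    {"stage": "production", "owner": "ml-team"},
    lambda *args: calls.append(args),
)
print(calls)
```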

SlimModelVersion

Bases: TypedDict

Slim, JSON-serializable version of mlflow.entities.model_registry.ModelVersion.
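A sketch of what a slim, JSON-serializable model version might look like. The field set is an assumption (the source does not list SlimModelVersion's keys), and SlimModelVersionSketch and to_slim are hypothetical; the point is that only plain strings and dicts are copied, so json.dumps succeeds where the full ModelVersion entity would not serialize directly.

```python
import json
from types import SimpleNamespace
from typing import Any, TypedDict


class SlimModelVersionSketch(TypedDict):
    """Hypothetical fields; the real SlimModelVersion may declare different keys."""

    name: str
    version: str
    run_id: str
    tags: dict[str, str]


def to_slim(mv: Any) -> SlimModelVersionSketch:
    # Copy only plain str/dict data so the result survives json.dumps.
    return SlimModelVersionSketch(
        name=mv.name,
        version=str(mv.version),
        run_id=mv.run_id,
        tags=dict(mv.tags),
    )


# SimpleNamespace stands in for mlflow.entities.model_registry.ModelVersion.
entity = SimpleNamespace(name="batch", version=3, run_id="abc123", tags={"stage": "production"})
slim = to_slim(entity)
print(json.dumps(slim))  # {"name": "batch", "version": "3", "run_id": "abc123", "tags": {"stage": "production"}}
```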