Skip to content

Model

YOLO

YOLO

A Python interface that emulates model-like behaviour by wrapping trainers.

Source code in ultralytics/yolo/engine/model.py
class YOLO:
    """
    YOLO

    A Python interface that emulates model-like behaviour by wrapping trainers.

    The task-specific model, trainer, validator and predictor classes are selected from the
    task inferred when the model is loaded or created (see ``_assign_ops_from_task``).
    """

    def __init__(self, model='yolov8n.yaml', type="v8") -> None:
        """
        Initializes the YOLO object.

        Args:
            model (str, Path): model to load or create. A '.pt' path is loaded as a checkpoint,
                a '.yaml' path is built as a new model from its configuration.
            type (str): Type/version of models to use. Defaults to "v8". It replaces the 'TYPE'
                placeholder in the class paths resolved by ``_assign_ops_from_task``.

        Raises:
            NotImplementedError: if ``model`` has neither a '.pt' nor a '.yaml' suffix.
        """
        self.type = type
        self.ModelClass = None  # model class
        self.TrainerClass = None  # trainer class
        self.ValidatorClass = None  # validator class
        self.PredictorClass = None  # predictor class
        self.predictor = None  # reuse predictor
        self.model = None  # model object
        self.trainer = None  # trainer object
        self.task = None  # task type
        self.ckpt = None  # if loaded from *.pt
        self.cfg = None  # if loaded from *.yaml
        self.ckpt_path = None  # checkpoint path, used when resuming training
        self.overrides = {}  # overrides for trainer object

        # Load or create new YOLO model, dispatching on the file suffix
        load_methods = {'.pt': self._load, '.yaml': self._new}
        suffix = Path(model).suffix
        if suffix not in load_methods:
            raise NotImplementedError(f"'{suffix}' model loading not implemented")
        load_methods[suffix](model)

    def __call__(self, source=None, stream=False, **kwargs):
        """Alias for ``predict``; see that method for arguments and return value."""
        return self.predict(source, stream, **kwargs)

    def _new(self, cfg: str, verbose=True):
        """
        Initializes a new model from a YAML config and infers the task type from the model definition.

        Args:
            cfg (str): model configuration file
            verbose (bool): display model info on load
        """
        cfg = check_yaml(cfg)  # check YAML
        cfg_dict = yaml_load(cfg, append_filename=True)  # model dict
        self.task = guess_model_task(cfg_dict)
        self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = \
            self._assign_ops_from_task(self.task)
        self.model = self.ModelClass(cfg_dict, verbose=verbose)  # initialize
        self.cfg = cfg

    def _load(self, weights: str):
        """
        Loads a model checkpoint and infers the task type from the args stored with the model.

        Args:
            weights (str): model checkpoint to be loaded
        """
        self.model, self.ckpt = attempt_load_one_weight(weights)
        self.ckpt_path = weights
        self.task = self.model.args["task"]
        # NOTE: self.overrides is the same dict object as self.model.args, so resetting the
        # checkpoint args below also removes them from the model's args.
        self.overrides = self.model.args
        self._reset_ckpt_args(self.overrides)
        self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = \
            self._assign_ops_from_task(self.task)

    def reset(self):
        """
        Resets the model modules.

        Calls ``reset_parameters()`` on every submodule that defines it, then sets
        ``requires_grad = True`` on every model parameter.
        """
        for m in self.model.modules():
            if hasattr(m, 'reset_parameters'):
                m.reset_parameters()
        for p in self.model.parameters():
            p.requires_grad = True

    def info(self, verbose=False):
        """
        Logs model info.

        Args:
            verbose (bool): Controls verbosity.
        """
        self.model.info(verbose=verbose)

    def fuse(self):
        """Fuses the model layers by delegating to the wrapped model's ``fuse()`` method."""
        self.model.fuse()

    def predict(self, source=None, stream=False, **kwargs):
        """
        Perform prediction using the YOLO model.

        The predictor is created on the first call and reused afterwards; later calls only refresh
        its arguments. ``conf`` defaults to 0.25 and ``save`` to False unless passed in ``kwargs``.

        Args:
            source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
                          Accepts all source types accepted by the YOLO model.
            stream (bool): Whether to stream the predictions or not. Defaults to False.
            **kwargs : Additional keyword arguments passed to the predictor.
                       Check the 'configuration' section in the documentation for all available options.

        Returns:
            (List[ultralytics.yolo.engine.results.Results]): The prediction results. When invoked from the
            ``yolo``/``ultralytics`` command line, the result of ``predictor.predict_cli`` is returned instead.
        """
        overrides = self.overrides.copy()
        overrides["conf"] = 0.25  # default confidence, may be overridden by kwargs
        overrides.update(kwargs)
        overrides["mode"] = "predict"
        overrides["save"] = kwargs.get("save", False)  # not save files by default
        if not self.predictor:
            self.predictor = self.PredictorClass(overrides=overrides)
            self.predictor.setup_model(model=self.model)
        else:  # only update args if predictor is already setup
            self.predictor.args = get_cfg(self.predictor.args, overrides)
        is_cli = sys.argv[0].endswith(('yolo', 'ultralytics'))
        if is_cli:
            return self.predictor.predict_cli(source=source)
        return self.predictor(source=source, stream=stream)

    @smart_inference_mode()
    def val(self, data=None, **kwargs):
        """
        Validate a model on a given dataset.

        Args:
            data (str): The dataset to validate on. Accepts all formats accepted by yolo.
                Falls back to the ``data`` value from the overrides when not given.
            **kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs

        Returns:
            The ``metrics`` attribute of the validator after it has run.
        """
        overrides = self.overrides.copy()
        overrides["rect"] = True  # rect batches as default
        overrides.update(kwargs)
        overrides["mode"] = "val"
        args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
        args.data = data or args.data
        args.task = self.task
        # A model created by _new() may lack an ``args`` dict (_new does not attach one), so only
        # reuse the trained imgsz when it is actually recorded instead of raising AttributeError/KeyError.
        model_args = getattr(self.model, 'args', None) or {}
        if args.imgsz == DEFAULT_CFG.imgsz and 'imgsz' in model_args:
            args.imgsz = model_args['imgsz']  # use trained imgsz unless custom value is passed
        args.imgsz = check_imgsz(args.imgsz, max_dim=1)

        validator = self.ValidatorClass(args=args)
        validator(model=self.model)
        return validator.metrics

    @smart_inference_mode()
    def export(self, **kwargs):
        """
        Export model.

        Args:
            **kwargs : Any other args accepted by the exporter. To see all args check 'configuration' section in docs
        """
        overrides = self.overrides.copy()
        overrides.update(kwargs)
        args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
        args.task = self.task
        # A model created by _new() may lack an ``args`` dict (_new does not attach one), so only
        # reuse the trained imgsz when it is actually recorded instead of raising AttributeError/KeyError.
        model_args = getattr(self.model, 'args', None) or {}
        if args.imgsz == DEFAULT_CFG.imgsz and 'imgsz' in model_args:
            args.imgsz = model_args['imgsz']  # use trained imgsz unless custom value is passed

        exporter = Exporter(overrides=args)
        exporter(model=self.model)

    def train(self, **kwargs):
        """
        Trains the model on a given dataset.

        If ``cfg`` is passed, the YAML file it points to replaces all other overrides. If ``resume``
        is truthy, training resumes from the checkpoint this object was loaded from.

        Args:
            **kwargs (Any): Any number of arguments representing the training configuration.

        Raises:
            AttributeError: if no ``data`` argument is available.
        """
        overrides = self.overrides.copy()
        overrides.update(kwargs)
        if kwargs.get("cfg"):
            LOGGER.info(f"cfg file passed. Overriding default params with {kwargs['cfg']}.")
            overrides = yaml_load(check_yaml(kwargs["cfg"]), append_filename=True)
        overrides["task"] = self.task
        overrides["mode"] = "train"
        if not overrides.get("data"):
            raise AttributeError("Dataset required but missing, i.e. pass 'data=coco128.yaml'")
        if overrides.get("resume"):
            overrides["resume"] = self.ckpt_path

        self.trainer = self.TrainerClass(overrides=overrides)
        if not overrides.get("resume"):  # manually set model only if not resuming
            self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
            self.model = self.trainer.model
        self.trainer.train()
        # update model and cfg after training
        if RANK in {0, -1}:
            self.model, _ = attempt_load_one_weight(str(self.trainer.best))
            self.overrides = self.model.args
            # Drop run-specific args (e.g. 'resume', 'epochs') exactly as _load() does, so they do
            # not leak into the next train()/predict()/val() call.
            self._reset_ckpt_args(self.overrides)
            # A cached predictor was set up with the pre-training model; rebuild it on next predict().
            self.predictor = None

    def to(self, device):
        """
        Sends the model to the given device.

        Args:
            device (str): device
        """
        self.model.to(device)

    def _assign_ops_from_task(self, task):
        """
        Resolves the model, trainer, validator and predictor classes for ``task``.

        The trainer/validator/predictor entries of MODEL_MAP are source strings with a 'TYPE'
        placeholder that is replaced by ``self.type`` and then evaluated.
        """
        model_class, train_lit, val_lit, pred_lit = MODEL_MAP[task]
        # SECURITY: eval() executes ``self.type`` as code once substituted into these strings.
        # Never pass an untrusted value as the ``type`` constructor argument.
        trainer_class = eval(train_lit.replace("TYPE", f"{self.type}"))
        validator_class = eval(val_lit.replace("TYPE", f"{self.type}"))
        predictor_class = eval(pred_lit.replace("TYPE", f"{self.type}"))

        return model_class, trainer_class, validator_class, predictor_class

    @property
    def names(self):
        """
         Returns class names of the loaded model.
        """
        return self.model.names

    @property
    def transforms(self):
        """
         Returns the transforms of the loaded model, or None if the model has no ``transforms`` attribute.
        """
        return self.model.transforms if hasattr(self.model, 'transforms') else None

    @staticmethod
    def add_callback(event: str, func):
        """
        Appends ``func`` to ``callbacks.default_callbacks[event]``.

        This modifies the module-level default callback registry, not per-instance state.
        """
        callbacks.default_callbacks[event].append(func)

    @staticmethod
    def _reset_ckpt_args(args):
        """Removes run-specific keys from a checkpoint args dict in place."""
        for arg in 'augment', 'verbose', 'project', 'name', 'exist_ok', 'resume', 'batch', 'epochs', 'cache', \
                'save_json', 'half', 'v5loader', 'device', 'cfg', 'save', 'rect', 'plots':
            args.pop(arg, None)

names property

Returns class names of the loaded model.

transforms property

Returns the transforms of the loaded model, or None if the model has none.

__init__(model='yolov8n.yaml', type='v8')

Initializes the YOLO object.

Parameters:

Name Type Description Default
model str, Path

model to load or create

'yolov8n.yaml'
type str

Type/version of models to use. Defaults to "v8".

'v8'
Source code in ultralytics/yolo/engine/model.py
def __init__(self, model='yolov8n.yaml', type="v8") -> None:
    """
    Initializes the YOLO object.

    Args:
        model (str, Path): model to load or create. A '.pt' path is loaded as a checkpoint,
            a '.yaml' path is built as a new model from its configuration.
        type (str): Type/version of models to use. Defaults to "v8".

    Raises:
        NotImplementedError: if ``model`` has neither a '.pt' nor a '.yaml' suffix.
    """
    self.type = type
    self.ModelClass = None  # model class
    self.TrainerClass = None  # trainer class
    self.ValidatorClass = None  # validator class
    self.PredictorClass = None  # predictor class
    self.predictor = None  # reuse predictor
    self.model = None  # model object
    self.trainer = None  # trainer object
    self.task = None  # task type
    self.ckpt = None  # if loaded from *.pt
    self.cfg = None  # if loaded from *.yaml
    self.ckpt_path = None  # checkpoint path, used when resuming training
    self.overrides = {}  # overrides for trainer object

    # Load or create new YOLO model, dispatching on the file suffix
    load_methods = {'.pt': self._load, '.yaml': self._new}
    suffix = Path(model).suffix
    if suffix not in load_methods:
        raise NotImplementedError(f"'{suffix}' model loading not implemented")
    load_methods[suffix](model)

add_callback(event, func) staticmethod

Adds a callback function to the default callbacks registered for the given event.

Source code in ultralytics/yolo/engine/model.py
@staticmethod
def add_callback(event: str, func):
    """
    Register ``func`` under ``event`` in the default callback registry.

    The registry is ``callbacks.default_callbacks``, which lives at module level rather than
    on any instance.
    """
    registry = callbacks.default_callbacks
    registry[event].append(func)

export(**kwargs)

Export model.

Parameters:

Name Type Description Default
**kwargs

Any other args accepted by the exporter. To see all args, check the 'configuration' section in the docs.

{}
Source code in ultralytics/yolo/engine/model.py
@smart_inference_mode()
def export(self, **kwargs):
    """
    Export model.

    Args:
        **kwargs : Any other args accepted by the exporter. To see all args check 'configuration' section in docs
    """
    overrides = self.overrides.copy()
    overrides.update(kwargs)
    args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
    args.task = self.task
    # A model created from a *.yaml config may lack an ``args`` dict, so only reuse the trained
    # imgsz when it is actually recorded instead of raising AttributeError/KeyError.
    model_args = getattr(self.model, 'args', None) or {}
    if args.imgsz == DEFAULT_CFG.imgsz and 'imgsz' in model_args:
        args.imgsz = model_args['imgsz']  # use trained imgsz unless custom value is passed

    exporter = Exporter(overrides=args)
    exporter(model=self.model)

info(verbose=False)

Logs model info.

Parameters:

Name Type Description Default
verbose bool

Controls verbosity.

False
Source code in ultralytics/yolo/engine/model.py
def info(self, verbose=False):
    """
    Log information about the wrapped model.

    Args:
        verbose (bool): Forwarded unchanged to the model's ``info`` method.
    """
    wrapped = self.model
    wrapped.info(verbose=verbose)

predict(source=None, stream=False, **kwargs)

Perform prediction using the YOLO model.

Parameters:

Name Type Description Default
source str | int | PIL | np.ndarray

The source of the image to make predictions on. Accepts all source types accepted by the YOLO model.

None
stream bool

Whether to stream the predictions or not. Defaults to False.

False
**kwargs

Additional keyword arguments passed to the predictor. Check the 'configuration' section in the documentation for all available options.

{}

Returns:

Type Description
List[ultralytics.yolo.engine.results.Results]

The prediction results.

Source code in ultralytics/yolo/engine/model.py
def predict(self, source=None, stream=False, **kwargs):
    """
    Perform prediction using the YOLO model.

    The predictor is created on the first call and reused afterwards; later calls only refresh
    its arguments. ``conf`` defaults to 0.25 and ``save`` to False unless passed in ``kwargs``.

    Args:
        source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
                      Accepts all source types accepted by the YOLO model.
        stream (bool): Whether to stream the predictions or not. Defaults to False.
        **kwargs : Additional keyword arguments passed to the predictor.
                   Check the 'configuration' section in the documentation for all available options.

    Returns:
        (List[ultralytics.yolo.engine.results.Results]): The prediction results. When invoked from the
        ``yolo``/``ultralytics`` command line, the result of ``predictor.predict_cli`` is returned instead.
    """
    overrides = self.overrides.copy()
    overrides["conf"] = 0.25  # default confidence, may be overridden by kwargs
    overrides.update(kwargs)
    overrides["mode"] = "predict"
    overrides["save"] = kwargs.get("save", False)  # not save files by default
    if not self.predictor:
        self.predictor = self.PredictorClass(overrides=overrides)
        self.predictor.setup_model(model=self.model)
    else:  # only update args if predictor is already setup
        self.predictor.args = get_cfg(self.predictor.args, overrides)
    is_cli = sys.argv[0].endswith(('yolo', 'ultralytics'))
    if is_cli:
        return self.predictor.predict_cli(source=source)
    return self.predictor(source=source, stream=stream)

reset()

Resets the model modules.

Source code in ultralytics/yolo/engine/model.py
def reset(self):
    """
    Re-initialise the wrapped model and make all of its parameters trainable.

    Each submodule exposing ``reset_parameters`` has it invoked; afterwards every parameter
    yielded by ``self.model.parameters()`` gets ``requires_grad = True``.
    """
    net = self.model
    for module in net.modules():
        if not hasattr(module, 'reset_parameters'):
            continue
        module.reset_parameters()
    for param in net.parameters():
        param.requires_grad = True

to(device)

Sends the model to the given device.

Parameters:

Name Type Description Default
device str

device

required
Source code in ultralytics/yolo/engine/model.py
def to(self, device):
    """
    Move the wrapped model onto ``device``.

    Args:
        device (str): device identifier, forwarded unchanged to the model's ``to`` method
    """
    target = self.model
    target.to(device)

train(**kwargs)

Trains the model on a given dataset.

Parameters:

Name Type Description Default
**kwargs Any

Any number of arguments representing the training configuration.

{}
Source code in ultralytics/yolo/engine/model.py
def train(self, **kwargs):
    """
    Trains the model on a given dataset.

    If ``cfg`` is passed, the YAML file it points to replaces all other overrides. If ``resume``
    is truthy, training resumes from the checkpoint this object was loaded from.

    Args:
        **kwargs (Any): Any number of arguments representing the training configuration.

    Raises:
        AttributeError: if no ``data`` argument is available.
    """
    overrides = self.overrides.copy()
    overrides.update(kwargs)
    if kwargs.get("cfg"):
        LOGGER.info(f"cfg file passed. Overriding default params with {kwargs['cfg']}.")
        overrides = yaml_load(check_yaml(kwargs["cfg"]), append_filename=True)
    overrides["task"] = self.task
    overrides["mode"] = "train"
    if not overrides.get("data"):
        raise AttributeError("Dataset required but missing, i.e. pass 'data=coco128.yaml'")
    if overrides.get("resume"):
        overrides["resume"] = self.ckpt_path

    self.trainer = self.TrainerClass(overrides=overrides)
    if not overrides.get("resume"):  # manually set model only if not resuming
        self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
        self.model = self.trainer.model
    self.trainer.train()
    # update model and cfg after training
    if RANK in {0, -1}:
        self.model, _ = attempt_load_one_weight(str(self.trainer.best))
        self.overrides = self.model.args
        # Drop run-specific args (e.g. 'resume', 'epochs') the same way checkpoint loading does,
        # so they do not leak into the next train()/predict()/val() call.
        self._reset_ckpt_args(self.overrides)
        # A cached predictor was set up with the pre-training model; rebuild it on next predict().
        self.predictor = None

val(data=None, **kwargs)

Validate a model on a given dataset.

Parameters:

Name Type Description Default
data str

The dataset to validate on. Accepts all formats accepted by yolo

None
**kwargs

Any other args accepted by the validators. To see all args check 'configuration' section in docs

{}
Source code in ultralytics/yolo/engine/model.py
@smart_inference_mode()
def val(self, data=None, **kwargs):
    """
    Validate a model on a given dataset.

    Args:
        data (str): The dataset to validate on. Accepts all formats accepted by yolo.
            Falls back to the ``data`` value from the overrides when not given.
        **kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs

    Returns:
        The ``metrics`` attribute of the validator after it has run.
    """
    overrides = self.overrides.copy()
    overrides["rect"] = True  # rect batches as default
    overrides.update(kwargs)
    overrides["mode"] = "val"
    args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
    args.data = data or args.data
    args.task = self.task
    # A model created from a *.yaml config may lack an ``args`` dict, so only reuse the trained
    # imgsz when it is actually recorded instead of raising AttributeError/KeyError.
    model_args = getattr(self.model, 'args', None) or {}
    if args.imgsz == DEFAULT_CFG.imgsz and 'imgsz' in model_args:
        args.imgsz = model_args['imgsz']  # use trained imgsz unless custom value is passed
    args.imgsz = check_imgsz(args.imgsz, max_dim=1)

    validator = self.ValidatorClass(args=args)
    validator(model=self.model)
    return validator.metrics