Skip to content

nn Module

Ultralytics nn module contains 3 main components:

  1. AutoBackend: A module that runs inference on models exported to many popular formats (PyTorch, TorchScript, ONNX, OpenVINO, CoreML, TensorRT, TensorFlow, PaddlePaddle)
  2. BaseModel: BaseModel class defines the operations supported by tasks like Detection and Segmentation
  3. modules: Optimized and reusable neural network blocks built on PyTorch.

AutoBackend

Bases: nn.Module

Source code in ultralytics/nn/autobackend.py
class AutoBackend(nn.Module):
    """
    Backend-agnostic wrapper that loads a YOLO model from one of several export formats and runs inference on it.

    The format is inferred from the weights path (see `_model_type`). Every local created while loading is copied onto
    the instance so that `forward()` can dispatch to the matching runtime.
    """

    def __init__(self, weights='yolov8n.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False, fuse=True):
        """
        MultiBackend class for python inference on various platforms using Ultralytics YOLO.

        Args:
            weights (str | list | torch.nn.Module): Path to the weights file, a list of weight paths (the first entry
                determines the format; the full list is passed to `attempt_load_weights` for PyTorch), or an
                in-memory PyTorch model. Default: 'yolov8n.pt'
            device (torch.device): The device to run the model on. TensorRT replaces a CPU device with 'cuda:0'.
            dnn (bool): Use OpenCV's DNN module for ONNX inference if True, defaults to False.
            data (str | Path, optional): Dataset YAML passed to `yaml_load`; its 'names' entry is used as the class
                names when the loaded model provides none.
            fp16 (bool): If True, use half precision. Only honoured for PyTorch, TorchScript, ONNX and TensorRT
                (TensorRT overrides it from the engine's input dtype). Default: False
            fuse (bool): Whether to fuse the model or not (PyTorch weights only). Default: True

        Supported formats and their naming conventions:
            | Format                | Suffix           |
            |-----------------------|------------------|
            | PyTorch               | *.pt             |
            | TorchScript           | *.torchscript    |
            | ONNX Runtime          | *.onnx           |
            | ONNX OpenCV DNN       | *.onnx --dnn     |
            | OpenVINO              | *.xml            |
            | CoreML                | *.mlmodel        |
            | TensorRT              | *.engine         |
            | TensorFlow SavedModel | *_saved_model    |
            | TensorFlow GraphDef   | *.pb             |
            | TensorFlow Lite       | *.tflite         |
            | TensorFlow Edge TPU   | *_edgetpu.tflite |
            | PaddlePaddle          | *_paddle_model   |
        """
        super().__init__()
        w = str(weights[0] if isinstance(weights, list) else weights)
        nn_module = isinstance(weights, torch.nn.Module)
        pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, triton = self._model_type(w)
        fp16 &= pt or jit or onnx or engine or nn_module  # FP16
        nhwc = coreml or saved_model or pb or tflite or edgetpu  # BHWC formats (vs torch BCHW)
        stride = 32  # default stride
        model = None  # TODO: resolves ONNX inference, verify effect on other backends
        cuda = torch.cuda.is_available() and device.type != 'cpu'  # use CUDA
        if not (pt or triton or nn_module):
            w = attempt_download_asset(w)  # download if not local

        # NOTE: special case: in-memory pytorch model
        if nn_module:
            model = weights.to(device)
            model = model.fuse() if fuse else model
            names = model.module.names if hasattr(model, 'module') else model.names  # get class names
            stride = max(int(model.stride.max()), 32)  # model stride
            model.half() if fp16 else model.float()
            self.model = model  # explicitly assign for to(), cpu(), cuda(), half()
            pt = True
        elif pt:  # PyTorch
            from ultralytics.nn.tasks import attempt_load_weights
            model = attempt_load_weights(weights if isinstance(weights, list) else w,
                                         device=device,
                                         inplace=True,
                                         fuse=fuse)
            stride = max(int(model.stride.max()), 32)  # model stride
            names = model.module.names if hasattr(model, 'module') else model.names  # get class names
            model.half() if fp16 else model.float()
            self.model = model  # explicitly assign for to(), cpu(), cuda(), half()
        elif jit:  # TorchScript
            LOGGER.info(f'Loading {w} for TorchScript inference...')
            extra_files = {'config.txt': ''}  # model metadata
            model = torch.jit.load(w, _extra_files=extra_files, map_location=device)
            model.half() if fp16 else model.float()
            if extra_files['config.txt']:  # load metadata dict
                # object_hook turns digit-string keys (JSON only has string keys) back into int class ids
                d = json.loads(extra_files['config.txt'],
                               object_hook=lambda d: {int(k) if k.isdigit() else k: v
                                                      for k, v in d.items()})
                stride, names = int(d['stride']), d['names']
        elif dnn:  # ONNX OpenCV DNN
            LOGGER.info(f'Loading {w} for ONNX OpenCV DNN inference...')
            check_requirements('opencv-python>=4.5.4')
            net = cv2.dnn.readNetFromONNX(w)
        elif onnx:  # ONNX Runtime
            LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
            check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime'))
            import onnxruntime
            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
            session = onnxruntime.InferenceSession(w, providers=providers)
            output_names = [x.name for x in session.get_outputs()]
            meta = session.get_modelmeta().custom_metadata_map  # metadata
            if 'stride' in meta:
                import ast
                # literal_eval, not eval: the metadata string comes from the model file and must not execute code
                stride, names = int(meta['stride']), ast.literal_eval(meta['names'])
        elif xml:  # OpenVINO
            LOGGER.info(f'Loading {w} for OpenVINO inference...')
            check_requirements('openvino')  # requires openvino-dev: https://pypi.org/project/openvino-dev/
            from openvino.runtime import Core, Layout, get_batch  # noqa
            ie = Core()
            if not Path(w).is_file():  # if not *.xml
                w = next(Path(w).glob('*.xml'))  # get *.xml file from *_openvino_model dir
            network = ie.read_model(model=w, weights=Path(w).with_suffix('.bin'))
            if network.get_parameters()[0].get_layout().empty:
                network.get_parameters()[0].set_layout(Layout("NCHW"))
            batch_dim = get_batch(network)
            if batch_dim.is_static:
                batch_size = batch_dim.get_length()
            executable_network = ie.compile_model(network, device_name="CPU")  # device_name="MYRIAD" for Intel NCS2
            meta_stride, meta_names = self._load_metadata(Path(w).with_suffix('.yaml'))  # load metadata
            if meta_names is not None:  # keep default stride / fallback names when the .yaml file is absent
                stride, names = meta_stride, meta_names
        elif engine:  # TensorRT
            LOGGER.info(f'Loading {w} for TensorRT inference...')
            import tensorrt as trt  # https://developer.nvidia.com/nvidia-tensorrt-download
            check_version(trt.__version__, '7.0.0', hard=True)  # require tensorrt>=7.0.0
            if device.type == 'cpu':
                device = torch.device('cuda:0')
            Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
            logger = trt.Logger(trt.Logger.INFO)
            with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
                model = runtime.deserialize_cuda_engine(f.read())
            context = model.create_execution_context()
            bindings = OrderedDict()
            output_names = []
            fp16 = False  # default updated below
            dynamic = False
            for i in range(model.num_bindings):
                name = model.get_binding_name(i)
                dtype = trt.nptype(model.get_binding_dtype(i))
                if model.binding_is_input(i):
                    if -1 in tuple(model.get_binding_shape(i)):  # dynamic
                        dynamic = True
                        # bind the profile's max shape so the pre-allocated buffers fit any valid input
                        context.set_binding_shape(i, tuple(model.get_profile_shape(0, i)[2]))
                    if dtype == np.float16:
                        fp16 = True
                else:  # output
                    output_names.append(name)
                shape = tuple(context.get_binding_shape(i))
                im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
                bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
            binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
            batch_size = bindings['images'].shape[0]  # if dynamic, this is instead max batch size
        elif coreml:  # CoreML
            LOGGER.info(f'Loading {w} for CoreML inference...')
            import coremltools as ct
            model = ct.models.MLModel(w)
        elif saved_model:  # TF SavedModel
            LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...')
            import tensorflow as tf
            keras = False  # assume TF1 saved_model
            model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
        elif pb:  # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
            LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...')
            import tensorflow as tf

            def wrap_frozen_graph(gd, inputs, outputs):
                x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), [])  # wrapped
                ge = x.graph.as_graph_element
                return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))

            def gd_outputs(gd):
                # graph outputs = nodes that never appear as another node's input (NoOp nodes excluded)
                name_list, input_list = [], []
                for node in gd.node:  # tensorflow.core.framework.node_def_pb2.NodeDef
                    name_list.append(node.name)
                    input_list.extend(node.input)
                return sorted(f'{x}:0' for x in list(set(name_list) - set(input_list)) if not x.startswith('NoOp'))

            gd = tf.Graph().as_graph_def()  # TF GraphDef
            with open(w, 'rb') as f:
                gd.ParseFromString(f.read())
            frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd))
        elif tflite or edgetpu:  # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
            try:  # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
                from tflite_runtime.interpreter import Interpreter, load_delegate
            except ImportError:
                import tensorflow as tf
                Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate
            if edgetpu:  # TF Edge TPU https://coral.ai/software/#edgetpu-runtime
                LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...')
                delegate = {
                    'Linux': 'libedgetpu.so.1',
                    'Darwin': 'libedgetpu.1.dylib',
                    'Windows': 'edgetpu.dll'}[platform.system()]
                interpreter = Interpreter(model_path=w, experimental_delegates=[load_delegate(delegate)])
            else:  # TFLite
                LOGGER.info(f'Loading {w} for TensorFlow Lite inference...')
                interpreter = Interpreter(model_path=w)  # load TFLite model
            interpreter.allocate_tensors()  # allocate
            input_details = interpreter.get_input_details()  # inputs
            output_details = interpreter.get_output_details()  # outputs
        elif tfjs:  # TF.js
            raise NotImplementedError('ERROR: YOLOv8 TF.js inference is not supported')
        elif paddle:  # PaddlePaddle
            LOGGER.info(f'Loading {w} for PaddlePaddle inference...')
            check_requirements('paddlepaddle-gpu' if cuda else 'paddlepaddle')
            import paddle.inference as pdi
            if not Path(w).is_file():  # if not *.pdmodel
                w = next(Path(w).rglob('*.pdmodel'))  # get *.pdmodel file from *_paddle_model dir
            weights = Path(w).with_suffix('.pdiparams')
            config = pdi.Config(str(w), str(weights))
            if cuda:
                config.enable_use_gpu(memory_pool_init_size_mb=2048, device_id=0)
            predictor = pdi.create_predictor(config)
            input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
            output_names = predictor.get_output_names()
        elif triton:  # NVIDIA Triton Inference Server
            LOGGER.info('Triton Inference Server not supported...')
            # NOTE(review): loading is not implemented, so `model` stays None and forward()/warmup() will fail
            # for Triton URLs. Planned implementation:
            #   check_requirements('tritonclient[all]')
            #   from utils.triton import TritonRemoteModel
            #   model = TritonRemoteModel(url=w)
            #   nhwc = model.runtime.startswith("tensorflow")
        else:
            raise NotImplementedError(f"ERROR: '{w}' is not a supported format. For supported formats see "
                                      f"https://docs.ultralytics.com/reference/nn/")

        # class names
        if 'names' not in locals():  # names missing
            names = yaml_load(data)['names'] if data else {i: f'class{i}' for i in range(999)}  # assign default
        names = check_class_names(names)

        self.__dict__.update(locals())  # assign all variables to self

    def forward(self, im, augment=False, visualize=False):
        """
        Runs inference on the YOLOv8 MultiBackend model.

        Args:
            im (torch.Tensor): Input batch in BCHW layout (unpacked as batch, channel, height, width).
            augment (bool): Passed as `augment=True` to the PyTorch model; ignored by other backends. Default: False
            visualize (bool): Passed as `visualize=True` to the PyTorch model; ignored by other backends.
                Default: False

        Returns:
            (torch.Tensor | list[torch.Tensor]): The backend output converted with `from_numpy` - a single tensor
                when the backend yields one output (or a non-list result), otherwise a list of tensors.
        """
        b, ch, h, w = im.shape  # batch, channel, height, width
        if self.fp16 and im.dtype != torch.float16:
            im = im.half()  # to FP16
        if self.nhwc:
            im = im.permute(0, 2, 3, 1)  # torch BCHW to numpy BHWC shape(1,320,192,3)

        if self.pt or self.nn_module:  # PyTorch
            y = self.model(im, augment=augment, visualize=visualize) if augment or visualize else self.model(im)
        elif self.jit:  # TorchScript
            y = self.model(im)
        elif self.dnn:  # ONNX OpenCV DNN
            im = im.cpu().numpy()  # torch to numpy
            self.net.setInput(im)
            y = self.net.forward()
        elif self.onnx:  # ONNX Runtime
            im = im.cpu().numpy()  # torch to numpy
            y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im})
        elif self.xml:  # OpenVINO
            im = im.cpu().numpy()  # FP32
            y = list(self.executable_network([im]).values())
        elif self.engine:  # TensorRT
            if self.dynamic and im.shape != self.bindings['images'].shape:
                i = self.model.get_binding_index('images')
                self.context.set_binding_shape(i, im.shape)  # reshape if dynamic
                self.bindings['images'] = self.bindings['images']._replace(shape=im.shape)
                for name in self.output_names:
                    i = self.model.get_binding_index(name)
                    self.bindings[name].data.resize_(tuple(self.context.get_binding_shape(i)))
            s = self.bindings['images'].shape
            assert im.shape == s, f"input size {im.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}"
            self.binding_addrs['images'] = int(im.data_ptr())
            self.context.execute_v2(list(self.binding_addrs.values()))
            y = [self.bindings[x].data for x in sorted(self.output_names)]
        elif self.coreml:  # CoreML
            im = im.cpu().numpy()
            im = Image.fromarray((im[0] * 255).astype('uint8'))
            # im = im.resize((192, 320), Image.ANTIALIAS)
            y = self.model.predict({'image': im})  # coordinates are xywh normalized
            if 'confidence' in y:
                box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]])  # xyxy pixels
                # builtin float (float64): `np.float` was an alias for it and was removed in NumPy 1.24
                conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(float)
                y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1)
            else:
                y = list(reversed(y.values()))  # reversed for segmentation models (pred, proto)
        elif self.paddle:  # PaddlePaddle
            im = im.cpu().numpy().astype(np.float32)
            self.input_handle.copy_from_cpu(im)
            self.predictor.run()
            y = [self.predictor.get_output_handle(x).copy_to_cpu() for x in self.output_names]
        elif self.triton:  # NVIDIA Triton Inference Server
            y = self.model(im)
        else:  # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
            im = im.cpu().numpy()
            if self.saved_model:  # SavedModel
                y = self.model(im, training=False) if self.keras else self.model(im)
            elif self.pb:  # GraphDef
                y = self.frozen_func(x=self.tf.constant(im))
            else:  # Lite or Edge TPU
                input = self.input_details[0]
                int8 = input['dtype'] == np.uint8  # is TFLite quantized uint8 model
                if int8:
                    scale, zero_point = input['quantization']
                    im = (im / scale + zero_point).astype(np.uint8)  # de-scale
                self.interpreter.set_tensor(input['index'], im)
                self.interpreter.invoke()
                y = []
                for output in self.output_details:
                    x = self.interpreter.get_tensor(output['index'])
                    if int8:
                        scale, zero_point = output['quantization']
                        x = (x.astype(np.float32) - zero_point) * scale  # re-scale
                    y.append(x)
            y = [x if isinstance(x, np.ndarray) else x.numpy() for x in y]
            y[0][..., :4] *= [w, h, w, h]  # xywh normalized to pixels

        if isinstance(y, (list, tuple)):
            return self.from_numpy(y[0]) if len(y) == 1 else [self.from_numpy(x) for x in y]
        else:
            return self.from_numpy(y)

    def from_numpy(self, x):
        """
        Convert a numpy array to a tensor on `self.device`; any other value is returned unchanged.

        Args:
            x (np.ndarray | Any): The value to be converted.

        Returns:
            (torch.Tensor | Any): The converted tensor, or `x` itself if it is not a numpy array.
        """
        return torch.from_numpy(x).to(self.device) if isinstance(x, np.ndarray) else x

    def warmup(self, imgsz=(1, 3, 640, 640)):
        """
        Warm up the model by running a forward pass with an uninitialized dummy input.

        Only runs for PyTorch, TorchScript, ONNX, TensorRT, SavedModel, GraphDef, Triton and in-memory models, and only
        when the device is not CPU (Triton always warms up). TorchScript gets two passes.

        Args:
            imgsz (tuple): The shape of the dummy input tensor in the format (batch_size, channels, height, width)

        Returns:
            (None): This method runs the forward pass and does not return any value
        """
        warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb, self.triton, self.nn_module
        if any(warmup_types) and (self.device.type != 'cpu' or self.triton):
            im = torch.empty(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device)  # input
            for _ in range(2 if self.jit else 1):
                self.forward(im)  # warmup

    @staticmethod
    def _model_type(p='path/to/model.pt'):
        """
        Infer the model format from a weights path or URL.

        Args:
            p (str): Path or URL of the model. Defaults to 'path/to/model.pt'

        Returns:
            (list[bool]): One flag per export suffix from `export_formats()`, in the order unpacked by `__init__`
                (pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle), followed by a
                Triton flag that is True only when no suffix matched and `p` is an http/grpc URL with a host.
        """
        from ultralytics.yolo.engine.exporter import export_formats
        sf = list(export_formats().Suffix)  # export suffixes
        if not is_url(p, check=False) and not isinstance(p, str):
            check_suffix(p, sf)  # checks
        url = urlparse(p)  # if url may be Triton inference server
        types = [s in Path(p).name for s in sf]
        types[8] &= not types[9]  # tflite &= not edgetpu ('_edgetpu.tflite' also contains '.tflite')
        triton = not any(types) and all([any(s in url.scheme for s in ["http", "grpc"]), url.netloc])
        return types + [triton]

    @staticmethod
    def _load_metadata(f=Path('path/to/meta.yaml')):
        """
        Load stride and class names from a metadata YAML file.

        Args:
            f (Path): The path to the metadata file.

        Returns:
            (tuple): (stride, names) read from the file, or (None, None) if the file does not exist.
        """
        if f.exists():
            d = yaml_load(f)
            return d['stride'], d['names']  # assign stride, names
        return None, None

__init__(weights='yolov8n.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False, fuse=True)

MultiBackend class for python inference on various platforms using Ultralytics YOLO.

Parameters:

Name Type Description Default
weights str

The path to the weights file. Default: 'yolov8n.pt'

'yolov8n.pt'
device torch.device

The device to run the model on.

torch.device('cpu')
dnn bool

Use OpenCV's DNN module for inference if True, defaults to False.

False
data dict

Path to a dataset YAML file whose 'names' entry supplies the class names when the model provides none; optional

None
fp16 bool

If True, use half precision. Default: False

False
fuse bool

Whether to fuse the model or not. Default: True

True
Supported formats and their naming conventions
Format Suffix
PyTorch *.pt
TorchScript *.torchscript
ONNX Runtime *.onnx
ONNX OpenCV DNN *.onnx --dnn
OpenVINO *.xml
CoreML *.mlmodel
TensorRT *.engine
TensorFlow SavedModel *_saved_model
TensorFlow GraphDef *.pb
TensorFlow Lite *.tflite
TensorFlow Edge TPU *_edgetpu.tflite
PaddlePaddle *_paddle_model
Source code in ultralytics/nn/autobackend.py
def __init__(self, weights='yolov8n.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False, fuse=True):
    """
    MultiBackend class for python inference on various platforms using Ultralytics YOLO.

    Args:
        weights (str | list | torch.nn.Module): Path to the weights file, a list of weight paths (the first entry
            determines the format; the full list is passed to `attempt_load_weights` for PyTorch), or an in-memory
            PyTorch model. Default: 'yolov8n.pt'
        device (torch.device): The device to run the model on. TensorRT replaces a CPU device with 'cuda:0'.
        dnn (bool): Use OpenCV's DNN module for ONNX inference if True, defaults to False.
        data (str | Path, optional): Dataset YAML passed to `yaml_load`; its 'names' entry is used as the class names
            when the loaded model provides none.
        fp16 (bool): If True, use half precision. Only honoured for PyTorch, TorchScript, ONNX and TensorRT
            (TensorRT overrides it from the engine's input dtype). Default: False
        fuse (bool): Whether to fuse the model or not (PyTorch weights only). Default: True

    Supported formats and their naming conventions:
        | Format                | Suffix           |
        |-----------------------|------------------|
        | PyTorch               | *.pt             |
        | TorchScript           | *.torchscript    |
        | ONNX Runtime          | *.onnx           |
        | ONNX OpenCV DNN       | *.onnx --dnn     |
        | OpenVINO              | *.xml            |
        | CoreML                | *.mlmodel        |
        | TensorRT              | *.engine         |
        | TensorFlow SavedModel | *_saved_model    |
        | TensorFlow GraphDef   | *.pb             |
        | TensorFlow Lite       | *.tflite         |
        | TensorFlow Edge TPU   | *_edgetpu.tflite |
        | PaddlePaddle          | *_paddle_model   |
    """
    super().__init__()
    w = str(weights[0] if isinstance(weights, list) else weights)
    nn_module = isinstance(weights, torch.nn.Module)
    pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, triton = self._model_type(w)
    fp16 &= pt or jit or onnx or engine or nn_module  # FP16
    nhwc = coreml or saved_model or pb or tflite or edgetpu  # BHWC formats (vs torch BCHW)
    stride = 32  # default stride
    model = None  # TODO: resolves ONNX inference, verify effect on other backends
    cuda = torch.cuda.is_available() and device.type != 'cpu'  # use CUDA
    if not (pt or triton or nn_module):
        w = attempt_download_asset(w)  # download if not local

    # NOTE: special case: in-memory pytorch model
    if nn_module:
        model = weights.to(device)
        model = model.fuse() if fuse else model
        names = model.module.names if hasattr(model, 'module') else model.names  # get class names
        stride = max(int(model.stride.max()), 32)  # model stride
        model.half() if fp16 else model.float()
        self.model = model  # explicitly assign for to(), cpu(), cuda(), half()
        pt = True
    elif pt:  # PyTorch
        from ultralytics.nn.tasks import attempt_load_weights
        model = attempt_load_weights(weights if isinstance(weights, list) else w,
                                     device=device,
                                     inplace=True,
                                     fuse=fuse)
        stride = max(int(model.stride.max()), 32)  # model stride
        names = model.module.names if hasattr(model, 'module') else model.names  # get class names
        model.half() if fp16 else model.float()
        self.model = model  # explicitly assign for to(), cpu(), cuda(), half()
    elif jit:  # TorchScript
        LOGGER.info(f'Loading {w} for TorchScript inference...')
        extra_files = {'config.txt': ''}  # model metadata
        model = torch.jit.load(w, _extra_files=extra_files, map_location=device)
        model.half() if fp16 else model.float()
        if extra_files['config.txt']:  # load metadata dict
            # object_hook turns digit-string keys (JSON only has string keys) back into int class ids
            d = json.loads(extra_files['config.txt'],
                           object_hook=lambda d: {int(k) if k.isdigit() else k: v
                                                  for k, v in d.items()})
            stride, names = int(d['stride']), d['names']
    elif dnn:  # ONNX OpenCV DNN
        LOGGER.info(f'Loading {w} for ONNX OpenCV DNN inference...')
        check_requirements('opencv-python>=4.5.4')
        net = cv2.dnn.readNetFromONNX(w)
    elif onnx:  # ONNX Runtime
        LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
        check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime'))
        import onnxruntime
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
        session = onnxruntime.InferenceSession(w, providers=providers)
        output_names = [x.name for x in session.get_outputs()]
        meta = session.get_modelmeta().custom_metadata_map  # metadata
        if 'stride' in meta:
            import ast
            # literal_eval, not eval: the metadata string comes from the model file and must not execute code
            stride, names = int(meta['stride']), ast.literal_eval(meta['names'])
    elif xml:  # OpenVINO
        LOGGER.info(f'Loading {w} for OpenVINO inference...')
        check_requirements('openvino')  # requires openvino-dev: https://pypi.org/project/openvino-dev/
        from openvino.runtime import Core, Layout, get_batch  # noqa
        ie = Core()
        if not Path(w).is_file():  # if not *.xml
            w = next(Path(w).glob('*.xml'))  # get *.xml file from *_openvino_model dir
        network = ie.read_model(model=w, weights=Path(w).with_suffix('.bin'))
        if network.get_parameters()[0].get_layout().empty:
            network.get_parameters()[0].set_layout(Layout("NCHW"))
        batch_dim = get_batch(network)
        if batch_dim.is_static:
            batch_size = batch_dim.get_length()
        executable_network = ie.compile_model(network, device_name="CPU")  # device_name="MYRIAD" for Intel NCS2
        meta_stride, meta_names = self._load_metadata(Path(w).with_suffix('.yaml'))  # load metadata
        if meta_names is not None:  # keep default stride / fallback names when the .yaml file is absent
            stride, names = meta_stride, meta_names
    elif engine:  # TensorRT
        LOGGER.info(f'Loading {w} for TensorRT inference...')
        import tensorrt as trt  # https://developer.nvidia.com/nvidia-tensorrt-download
        check_version(trt.__version__, '7.0.0', hard=True)  # require tensorrt>=7.0.0
        if device.type == 'cpu':
            device = torch.device('cuda:0')
        Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
        logger = trt.Logger(trt.Logger.INFO)
        with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
            model = runtime.deserialize_cuda_engine(f.read())
        context = model.create_execution_context()
        bindings = OrderedDict()
        output_names = []
        fp16 = False  # default updated below
        dynamic = False
        for i in range(model.num_bindings):
            name = model.get_binding_name(i)
            dtype = trt.nptype(model.get_binding_dtype(i))
            if model.binding_is_input(i):
                if -1 in tuple(model.get_binding_shape(i)):  # dynamic
                    dynamic = True
                    # bind the profile's max shape so the pre-allocated buffers fit any valid input
                    context.set_binding_shape(i, tuple(model.get_profile_shape(0, i)[2]))
                if dtype == np.float16:
                    fp16 = True
            else:  # output
                output_names.append(name)
            shape = tuple(context.get_binding_shape(i))
            im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
            bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
        binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
        batch_size = bindings['images'].shape[0]  # if dynamic, this is instead max batch size
    elif coreml:  # CoreML
        LOGGER.info(f'Loading {w} for CoreML inference...')
        import coremltools as ct
        model = ct.models.MLModel(w)
    elif saved_model:  # TF SavedModel
        LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...')
        import tensorflow as tf
        keras = False  # assume TF1 saved_model
        model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
    elif pb:  # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
        LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...')
        import tensorflow as tf

        def wrap_frozen_graph(gd, inputs, outputs):
            x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), [])  # wrapped
            ge = x.graph.as_graph_element
            return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))

        def gd_outputs(gd):
            # graph outputs = nodes that never appear as another node's input (NoOp nodes excluded)
            name_list, input_list = [], []
            for node in gd.node:  # tensorflow.core.framework.node_def_pb2.NodeDef
                name_list.append(node.name)
                input_list.extend(node.input)
            return sorted(f'{x}:0' for x in list(set(name_list) - set(input_list)) if not x.startswith('NoOp'))

        gd = tf.Graph().as_graph_def()  # TF GraphDef
        with open(w, 'rb') as f:
            gd.ParseFromString(f.read())
        frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd))
    elif tflite or edgetpu:  # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
        try:  # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
            from tflite_runtime.interpreter import Interpreter, load_delegate
        except ImportError:
            import tensorflow as tf
            Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate
        if edgetpu:  # TF Edge TPU https://coral.ai/software/#edgetpu-runtime
            LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...')
            delegate = {
                'Linux': 'libedgetpu.so.1',
                'Darwin': 'libedgetpu.1.dylib',
                'Windows': 'edgetpu.dll'}[platform.system()]
            interpreter = Interpreter(model_path=w, experimental_delegates=[load_delegate(delegate)])
        else:  # TFLite
            LOGGER.info(f'Loading {w} for TensorFlow Lite inference...')
            interpreter = Interpreter(model_path=w)  # load TFLite model
        interpreter.allocate_tensors()  # allocate
        input_details = interpreter.get_input_details()  # inputs
        output_details = interpreter.get_output_details()  # outputs
    elif tfjs:  # TF.js
        raise NotImplementedError('ERROR: YOLOv8 TF.js inference is not supported')
    elif paddle:  # PaddlePaddle
        LOGGER.info(f'Loading {w} for PaddlePaddle inference...')
        check_requirements('paddlepaddle-gpu' if cuda else 'paddlepaddle')
        import paddle.inference as pdi
        if not Path(w).is_file():  # if not *.pdmodel
            w = next(Path(w).rglob('*.pdmodel'))  # get *.pdmodel file from *_paddle_model dir
        weights = Path(w).with_suffix('.pdiparams')
        config = pdi.Config(str(w), str(weights))
        if cuda:
            config.enable_use_gpu(memory_pool_init_size_mb=2048, device_id=0)
        predictor = pdi.create_predictor(config)
        input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
        output_names = predictor.get_output_names()
    elif triton:  # NVIDIA Triton Inference Server
        LOGGER.info('Triton Inference Server not supported...')
        # NOTE(review): loading is not implemented, so `model` stays None and inference will fail for Triton URLs.
        # Planned implementation:
        #   check_requirements('tritonclient[all]')
        #   from utils.triton import TritonRemoteModel
        #   model = TritonRemoteModel(url=w)
        #   nhwc = model.runtime.startswith("tensorflow")
    else:
        raise NotImplementedError(f"ERROR: '{w}' is not a supported format. For supported formats see "
                                  f"https://docs.ultralytics.com/reference/nn/")

    # class names
    if 'names' not in locals():  # names missing
        names = yaml_load(data)['names'] if data else {i: f'class{i}' for i in range(999)}  # assign default
    names = check_class_names(names)

    self.__dict__.update(locals())  # assign all variables to self

forward(im, augment=False, visualize=False)

Runs inference on the YOLOv8 MultiBackend model.

Parameters:

Name Type Description Default
im torch.Tensor

The image tensor to perform inference on.

required
augment bool

Whether to pass augment=True to the PyTorch model (ignored by other backends), defaults to False

False
visualize bool

Whether to pass visualize=True to the PyTorch model (ignored by other backends), defaults to False

False

Returns:

Type Description
torch.Tensor | list[torch.Tensor]

The model output converted to tensors: a single tensor when the backend yields one output, otherwise a list of tensors.

Source code in ultralytics/nn/autobackend.py
def forward(self, im, augment=False, visualize=False):
    """
    Runs inference on the YOLOv8 MultiBackend model.

    Args:
        im (torch.Tensor): The image tensor to perform inference on, shaped (batch, channel, height, width).
        augment (bool): Whether to use test-time augmentation (only passed to PyTorch models), defaults to False
        visualize (bool): Whether to enable visualization (only passed to PyTorch models), defaults to False

    Returns:
        (torch.Tensor | list): The model output. A single output is returned directly, multiple outputs are
            returned as a list. NumPy arrays produced by non-PyTorch backends are converted to tensors on
            `self.device`.
    """
    b, ch, h, w = im.shape  # batch, channel, height, width
    if self.fp16 and im.dtype != torch.float16:
        im = im.half()  # to FP16
    if self.nhwc:
        im = im.permute(0, 2, 3, 1)  # torch BCHW to numpy BHWC shape(1,320,192,3)

    if self.pt or self.nn_module:  # PyTorch
        y = self.model(im, augment=augment, visualize=visualize) if augment or visualize else self.model(im)
    elif self.jit:  # TorchScript
        y = self.model(im)
    elif self.dnn:  # ONNX OpenCV DNN
        im = im.cpu().numpy()  # torch to numpy
        self.net.setInput(im)
        y = self.net.forward()
    elif self.onnx:  # ONNX Runtime
        im = im.cpu().numpy()  # torch to numpy
        y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im})
    elif self.xml:  # OpenVINO
        im = im.cpu().numpy()  # FP32
        y = list(self.executable_network([im]).values())
    elif self.engine:  # TensorRT
        if self.dynamic and im.shape != self.bindings['images'].shape:
            i = self.model.get_binding_index('images')
            self.context.set_binding_shape(i, im.shape)  # reshape if dynamic
            self.bindings['images'] = self.bindings['images']._replace(shape=im.shape)
            for name in self.output_names:
                i = self.model.get_binding_index(name)
                self.bindings[name].data.resize_(tuple(self.context.get_binding_shape(i)))
        s = self.bindings['images'].shape
        assert im.shape == s, f"input size {im.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}"
        self.binding_addrs['images'] = int(im.data_ptr())
        self.context.execute_v2(list(self.binding_addrs.values()))
        y = [self.bindings[x].data for x in sorted(self.output_names)]
    elif self.coreml:  # CoreML
        im = im.cpu().numpy()
        im = Image.fromarray((im[0] * 255).astype('uint8'))
        y = self.model.predict({'image': im})  # coordinates are xywh normalized
        if 'confidence' in y:
            box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]])  # xyxy pixels
            # np.float (an alias of the builtin float) was removed in NumPy 1.24; use builtin float directly
            conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(float)
            y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1)
        else:
            y = list(reversed(y.values()))  # reversed for segmentation models (pred, proto)
    elif self.paddle:  # PaddlePaddle
        im = im.cpu().numpy().astype(np.float32)
        self.input_handle.copy_from_cpu(im)
        self.predictor.run()
        y = [self.predictor.get_output_handle(x).copy_to_cpu() for x in self.output_names]
    elif self.triton:  # NVIDIA Triton Inference Server
        y = self.model(im)
    else:  # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
        im = im.cpu().numpy()
        if self.saved_model:  # SavedModel
            y = self.model(im, training=False) if self.keras else self.model(im)
        elif self.pb:  # GraphDef
            y = self.frozen_func(x=self.tf.constant(im))
        else:  # Lite or Edge TPU
            inp = self.input_details[0]
            int8 = inp['dtype'] == np.uint8  # is TFLite quantized uint8 model
            if int8:
                scale, zero_point = inp['quantization']
                im = (im / scale + zero_point).astype(np.uint8)  # de-scale
            self.interpreter.set_tensor(inp['index'], im)
            self.interpreter.invoke()
            y = []
            for output in self.output_details:
                x = self.interpreter.get_tensor(output['index'])
                if int8:
                    scale, zero_point = output['quantization']
                    x = (x.astype(np.float32) - zero_point) * scale  # re-scale
                y.append(x)
        y = [x if isinstance(x, np.ndarray) else x.numpy() for x in y]
        y[0][..., :4] *= [w, h, w, h]  # xywh normalized to pixels

    # Unwrap single outputs; convert any NumPy results to tensors on the model device
    if isinstance(y, (list, tuple)):
        return self.from_numpy(y[0]) if len(y) == 1 else [self.from_numpy(x) for x in y]
    else:
        return self.from_numpy(y)

from_numpy(x)

Convert a numpy array to a tensor on the model's device. Inputs that are not numpy arrays are returned unchanged.

Parameters:

Name Type Description Default
x np.ndarray

The array to be converted.

required

Returns:

Type Description
torch.Tensor

The converted tensor

Source code in ultralytics/nn/autobackend.py
def from_numpy(self, x):
    """
    Turn a NumPy array into a torch tensor on the backend's device.

    Args:
        x (np.ndarray): The array to be converted. Objects that are not NumPy arrays are returned untouched.

    Returns:
        (torch.Tensor): The converted tensor, moved to `self.device` (or `x` itself if it was not an array).
    """
    if not isinstance(x, np.ndarray):
        return x
    tensor = torch.from_numpy(x)
    return tensor.to(self.device)

warmup(imgsz=(1, 3, 640, 640))

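Warm up the model by running a forward pass on a dummy input (two passes for TorchScript models). Warmup only runs for supported backends, and only on a non-CPU device or when using Triton.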

Parameters:

Name Type Description Default
imgsz tuple

The shape of the dummy input tensor in the format (batch_size, channels, height, width)

(1, 3, 640, 640)

Returns:

Type Description
None

This method runs the forward pass and doesn't return any value

Source code in ultralytics/nn/autobackend.py
def warmup(self, imgsz=(1, 3, 640, 640)):
    """
    Warm up the model by pushing a dummy input through `forward`.

    Warmup only happens for PyTorch, TorchScript, ONNX, TensorRT, TF SavedModel/GraphDef, Triton and nn.Module
    backends, and only when running on a non-CPU device or through Triton. TorchScript models get two passes.

    Args:
        imgsz (tuple): The shape of the dummy input tensor in the format (batch_size, channels, height, width)

    Returns:
        (None): This method runs the forward pass and doesn't return any value
    """
    backends = (self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb, self.triton, self.nn_module)
    if not any(backends):
        return
    if self.device.type == 'cpu' and not self.triton:
        return
    dtype = torch.half if self.fp16 else torch.float
    dummy = torch.empty(*imgsz, dtype=dtype, device=self.device)  # contents are irrelevant for warmup
    passes = 2 if self.jit else 1
    for _ in range(passes):
        self.forward(dummy)

BaseModel

Bases: nn.Module

The BaseModel class serves as a base class for all the models in the Ultralytics YOLO family.

Source code in ultralytics/nn/tasks.py
class BaseModel(nn.Module):
    """
    The BaseModel class serves as a base class for all the models in the Ultralytics YOLO family.
    """

    def forward(self, x, profile=False, visualize=False):
        """
        Forward pass of the model on a single scale.
        Wrapper for `_forward_once` method.

        Args:
            x (torch.Tensor): The input image tensor
            profile (bool): Whether to profile the model, defaults to False
            visualize (bool): Whether to visualize intermediate feature maps (not yet supported; currently only
                logs a message), defaults to False

        Returns:
            (torch.Tensor): The output of the network.
        """
        return self._forward_once(x, profile, visualize)

    def _forward_once(self, x, profile=False, visualize=False):
        """
        Perform a forward pass through the network.

        Each layer `m` in `self.model` reads its input from `m.f` (-1 for the previous layer, another layer index,
        or a list of such indices) and its output is kept for later layers when `m.i` is listed in `self.save`.

        Args:
            x (torch.Tensor): The input tensor to the model
            profile (bool):  Print the computation time of each layer if True, defaults to False.
            visualize (bool): Feature-map visualization flag; not yet supported, only logs a message. Defaults to False

        Returns:
            (torch.Tensor): The last output of the model.
        """
        y, dt = [], []  # saved layer outputs, per-layer profiling times
        for m in self.model:
            if m.f != -1:  # if not from previous layer
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
            if profile:
                self._profile_one_layer(m, x, dt)
            x = m(x)  # run
            y.append(x if m.i in self.save else None)  # keep output only if its index is in self.save
            if visualize:
                LOGGER.info('visualize feature not yet supported')
                # TODO: feature_visualization(x, m.type, m.i, save_dir=visualize)
        return x

    def _profile_one_layer(self, m, x, dt):
        """
        Profile the computation time and FLOPs of a single layer of the model on a given input.
        Appends the results to the provided list.

        Args:
            m (nn.Module): The layer to be profiled.
            x (torch.Tensor | list): The input data to the layer; a list of tensors for multi-input layers.
            dt (list): A list to store the computation time of the layer.

        Returns:
            None
        """
        is_last = m is self.model[-1]  # final layer is fed a copy of its input as an inplace fix

        def fresh_input():
            # Multi-input layers (e.g. a Detect head) receive a Python list, which has no .clone(); a shallow
            # list copy protects the caller's list, while tensors are cloned.
            if not is_last:
                return x
            return x.copy() if isinstance(x, list) else x.clone()

        o = thop.profile(m, inputs=(fresh_input(),), verbose=False)[0] / 1E9 * 2 if thop else 0  # GFLOPs
        t = time_sync()
        for _ in range(10):
            m(fresh_input())
        dt.append((time_sync() - t) * 100)  # ms per run: seconds * 1000 / 10 runs
        if m is self.model[0]:
            LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s}  module")
        LOGGER.info(f'{dt[-1]:10.2f} {o:10.2f} {m.np:10.0f}  {m.type}')
        if is_last:
            LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s}  Total")

    def fuse(self):
        """
        Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer, in order to improve the
        computation efficiency.

        Returns:
            (nn.Module): The fused model is returned.
        """
        if not self.is_fused():
            for m in self.model.modules():
                if isinstance(m, (Conv, DWConv)) and hasattr(m, 'bn'):
                    m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update conv
                    delattr(m, 'bn')  # remove batchnorm
                    m.forward = m.forward_fuse  # update forward
                if isinstance(m, ConvTranspose) and hasattr(m, 'bn'):
                    m.conv_transpose = fuse_deconv_and_bn(m.conv_transpose, m.bn)
                    delattr(m, 'bn')  # remove batchnorm
                    m.forward = m.forward_fuse  # update forward
            self.info()

        return self

    def is_fused(self, thresh=10):
        """
        Check if the model has fewer than a certain threshold of normalization layers.

        Normalization layers are all `torch.nn` classes whose name contains 'Norm' (e.g. BatchNorm2d, LayerNorm).

        Args:
            thresh (int, optional): The threshold number of normalization layers. Default is 10.

        Returns:
            (bool): True if the number of normalization layers in the model is less than the threshold, False
                otherwise.
        """
        bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # normalization layers, i.e. BatchNorm2d()
        return sum(isinstance(v, bn) for v in self.modules()) < thresh  # True if < 'thresh' BatchNorm layers in model

    def info(self, verbose=False, imgsz=640):
        """
        Prints model information

        Args:
            verbose (bool): if True, prints out the model information. Defaults to False
            imgsz (int): the size of the image that the model will be trained on. Defaults to 640
        """
        model_info(self, verbose, imgsz)

    def _apply(self, fn):
        """
        Apply `fn` to the model's tensors (used by e.g. `.to()`, `.cuda()`, `.half()`).

        `nn.Module._apply` converts the registered parameters and buffers; additionally, when the final layer is a
        Detect or Segment head, its `stride`, `anchors` and `strides` tensors are converted explicitly here.

        Args:
            fn (function): The function to apply to the model's tensors.

        Returns:
            (BaseModel): The updated model (self).
        """
        self = super()._apply(fn)
        m = self.model[-1]  # Detect()
        if isinstance(m, (Detect, Segment)):
            m.stride = fn(m.stride)
            m.anchors = fn(m.anchors)
            m.strides = fn(m.strides)
        return self

    def load(self, weights):
        """
        Load the weights of the model from a file. Must be implemented by derived classes.

        Args:
            weights (str): The weights to load into the model.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        # Force all tasks to implement this function
        raise NotImplementedError("This function needs to be implemented by derived classes!")

forward(x, profile=False, visualize=False)

Forward pass of the model on a single scale. Wrapper for _forward_once method.

Parameters:

Name Type Description Default
x torch.Tensor

The input image tensor

required
profile bool

Whether to profile the model, defaults to False

False
visualize bool

Whether to visualize intermediate feature maps (not yet supported; currently only logs a message), defaults to False

False

Returns:

Type Description
torch.Tensor

The output of the network.

Source code in ultralytics/nn/tasks.py
def forward(self, x, profile=False, visualize=False):
    """
    Run a single-scale forward pass by delegating to `_forward_once`.

    Args:
        x (torch.Tensor): The input image tensor
        profile (bool): If True, profile each layer, defaults to False
        visualize (bool): Visualization flag passed through to `_forward_once`, defaults to False

    Returns:
        (torch.Tensor): The output of the network.
    """
    run_once = self._forward_once
    return run_once(x, profile, visualize)

fuse()

Fuse the Conv2d() and BatchNorm2d() layers of the model into a single layer, in order to improve the computation efficiency.

Returns:

Type Description
nn.Module

The fused model is returned.

Source code in ultralytics/nn/tasks.py
def fuse(self):
    """
    Merge each convolution with its following batch-norm into one layer for faster inference.

    Handles `Conv`/`DWConv` (via `fuse_conv_and_bn`) and `ConvTranspose` (via `fuse_deconv_and_bn`) modules that
    still carry a `bn` attribute, then prints model info. Does nothing if the model already counts as fused.

    Returns:
        (nn.Module): The (possibly fused) model itself.
    """
    if self.is_fused():
        return self

    for module in self.model.modules():
        if isinstance(module, (Conv, DWConv)) and hasattr(module, 'bn'):
            module.conv = fuse_conv_and_bn(module.conv, module.bn)  # fold BN into the conv weights
            delattr(module, 'bn')
            module.forward = module.forward_fuse
        elif isinstance(module, ConvTranspose) and hasattr(module, 'bn'):
            module.conv_transpose = fuse_deconv_and_bn(module.conv_transpose, module.bn)
            delattr(module, 'bn')
            module.forward = module.forward_fuse
    self.info()
    return self

info(verbose=False, imgsz=640)

Prints model information

Parameters:

Name Type Description Default
verbose bool

if True, prints out the model information. Defaults to False

False
imgsz int

the size of the image that the model will be trained on. Defaults to 640

640
Source code in ultralytics/nn/tasks.py
def info(self, verbose=False, imgsz=640):
    """
    Print a summary of the model by delegating to `model_info`.

    Args:
        verbose (bool): If True, print detailed model information. Defaults to False
        imgsz (int): Image size forwarded to `model_info`. Defaults to 640
    """
    model_info(self, verbose, imgsz)

is_fused(thresh=10)

Check if the model has fewer than a certain threshold of normalization layers (any torch.nn class whose name contains "Norm", e.g. BatchNorm2d).

Parameters:

Name Type Description Default
thresh int

The threshold number of BatchNorm layers. Default is 10.

10

Returns:

Type Description
bool

True if the number of BatchNorm layers in the model is less than the threshold, False otherwise.

Source code in ultralytics/nn/tasks.py
def is_fused(self, thresh=10):
    """
    Report whether the model contains fewer normalization layers than `thresh`.

    Every `torch.nn` class whose name contains 'Norm' (BatchNorm2d, LayerNorm, ...) counts as a normalization layer.

    Args:
        thresh (int, optional): Number of normalization layers at which the model is no longer considered fused.
            Default is 10.

    Returns:
        (bool): True if fewer than `thresh` normalization layers are present, False otherwise.
    """
    norm_types = tuple(cls for name, cls in nn.__dict__.items() if 'Norm' in name)
    n_norm = sum(1 for module in self.modules() if isinstance(module, norm_types))
    return n_norm < thresh

load(weights)

Load the weights of the model from a file. Derived classes must implement this; the base implementation raises NotImplementedError.

Parameters:

Name Type Description Default
weights str

The weights to load into the model.

required
Source code in ultralytics/nn/tasks.py
def load(self, weights):
    """
    Load model weights from a file; must be overridden by every task-specific subclass.

    Args:
        weights (str): The weights to load into the model.

    Raises:
        NotImplementedError: Always, since the base class has no loading logic.
    """
    msg = "This function needs to be implemented by derived classes!"
    raise NotImplementedError(msg)

Modules

TODO