# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from copy import copy

from ultralytics.models import yolo
from ultralytics.nn.tasks import PoseModel
from ultralytics.utils import DEFAULT_CFG, LOGGER
from ultralytics.utils.plotting import plot_images, plot_results


class PoseTrainer(yolo.detect.DetectionTrainer):
    """
    A class extending the DetectionTrainer class for training YOLO pose estimation models.

    This trainer specializes in handling pose estimation tasks, managing model training, validation, and
    visualization of pose keypoints alongside bounding boxes.

    Attributes:
        args (Dict): Configuration arguments for training.
        model (PoseModel): The pose estimation model being trained.
        data (Dict): Dataset configuration; this class reads its "nc" and "kpt_shape" entries.
        loss_names (Tuple[str]): Names of the loss components used in training (assigned in `get_validator`).

    Methods:
        get_model: Retrieves a pose estimation model with specified configuration.
        set_model_attributes: Sets keypoints shape attribute on the model.
        get_validator: Creates a validator instance for model evaluation.
        plot_training_samples: Visualizes training samples with keypoints.
        plot_metrics: Generates and saves training/validation metric plots.

    Examples:
        >>> from ultralytics.models.yolo.pose import PoseTrainer
        >>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml", epochs=3)
        >>> trainer = PoseTrainer(overrides=args)
        >>> trainer.train()
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initialize a PoseTrainer object with specified configurations and overrides.

        Args:
            cfg: Base configuration forwarded to DetectionTrainer; defaults to DEFAULT_CFG.
            overrides (dict, optional): Configuration overrides. A shallow copy is used, so the caller's dict is
                left untouched; the copy's "task" key is always forced to "pose".
            _callbacks: Callbacks forwarded unchanged to DetectionTrainer.
        """
        # Copy rather than mutate the caller's dict: writing "task" into it in place would leak task="pose"
        # into any dict the caller reuses afterwards.
        overrides = {} if overrides is None else dict(overrides)
        overrides["task"] = "pose"
        super().__init__(cfg, overrides, _callbacks)

        # self.args is expected to be populated by the parent constructor. Warn only when the user explicitly
        # requested the Apple MPS device, which has a known pose-training issue (linked in the message).
        if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
            LOGGER.warning(
                "WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
                "See https://github.com/ultralytics/ultralytics/issues/4031."
            )

    def get_model(self, cfg=None, weights=None, verbose=True):
        """
        Build a PoseModel sized for the dataset's class count and keypoint shape, optionally loading weights.

        Args:
            cfg: Model configuration passed through to PoseModel.
            weights: Optional weights; when truthy they are loaded into the new model via `model.load`.
            verbose (bool): Forwarded to PoseModel.

        Returns:
            (PoseModel): The constructed (and possibly weight-initialized) model.
        """
        # ch=3: presumably 3-channel (RGB) input images — confirm against the dataset pipeline.
        model = PoseModel(cfg, ch=3, nc=self.data["nc"], data_kpt_shape=self.data["kpt_shape"], verbose=verbose)
        if weights:
            model.load(weights)
        return model

    def set_model_attributes(self):
        """Set the parent-class model attributes, then copy the dataset's keypoint shape onto the model."""
        super().set_model_attributes()
        self.model.kpt_shape = self.data["kpt_shape"]

    def get_validator(self):
        """
        Create the PoseValidator used for evaluation and record the loss component names.

        Returns:
            (PoseValidator): Validator bound to the test loader, save directory, a shallow copy of the args,
                and this trainer's callbacks.
        """
        # NOTE(review): this order is assumed to match the loss items produced by the pose loss — confirm.
        self.loss_names = "box_loss", "pose_loss", "kobj_loss", "cls_loss", "dfl_loss"
        return yolo.pose.PoseValidator(
            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
        )

    def plot_training_samples(self, batch, ni):
        """
        Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints.

        Args:
            batch (dict): Training batch; this method reads the "img", "keypoints", "cls", "bboxes", "im_file"
                and "batch_idx" entries.
            ni (int): Iteration index, used only to name the output file `train_batch{ni}.jpg` in save_dir.
        """
        images = batch["img"]
        kpts = batch["keypoints"]
        # Drop the trailing singleton dimension of the class tensor (assumed shape (N, 1) — confirm).
        cls = batch["cls"].squeeze(-1)
        bboxes = batch["bboxes"]
        paths = batch["im_file"]
        batch_idx = batch["batch_idx"]
        plot_images(
            images,
            batch_idx,
            cls,
            bboxes,
            kpts=kpts,
            paths=paths,
            fname=self.save_dir / f"train_batch{ni}.jpg",
            on_plot=self.on_plot,
        )

    def plot_metrics(self):
        """Plot the training/validation metrics recorded in self.csv, including pose-specific metrics."""
        plot_results(file=self.csv, pose=True, on_plot=self.on_plot)  # save results.png