Training PoseC3D on a Custom Dataset with pyskl

I've been studying video action recognition lately and have looked at quite a few related algorithms, mainly a series of works built on the MMDetection framework. Some of them recognize actions directly from video to extract human actions, such as SlowFast, but these often have their limitations. My own work has always centered on skeleton keypoints, and I tried skeleton-based methods in the past with poor results. Recently I came across a newly published skeleton-based video action recognition algorithm, PoseC3D, which caught my interest, so here are the paper and open-source code links.

Paper: https://arxiv.org/abs/2104.13586
Code: https://github.com/kennymckormick/pyskl

I haven't written up my own understanding of the paper here; I'll fill that in later when I have time. For now I'll just document how to use this algorithm to train on your own dataset.
1. Environment Setup

To set up the environment for this algorithm, it's best to configure the MMDetection environment first; following the installation tutorial on the official site, it's fairly straightforward. Note that mmcv, mmcv-full, mmdet and mmpose all need to be installed, and the versions can't be too high — installing the latest versions leads to compilation failures. These are the exact versions I installed:

mmcv=1.3.18, mmcv-full=1.3.18, mmdet=2.23.0, mmpose=0.24.0
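To confirm the pins took effect, a minimal sanity check (my addition, not part of the original setup steps):

# Verify that the pinned OpenMMLab versions are the ones actually imported.
import mmcv
import mmdet
import mmpose

print(mmcv.__version__)    # expect 1.3.18
print(mmdet.__version__)   # expect 2.23.0
print(mmpose.__version__)  # expect 0.24.0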
One more tip: when installing MMDetection, it's best to install Detectron2 as well; installing it with the command line from its official site is enough. Here is the install command:
python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'
With that in place, installing and compiling with the commands below will go through without problems.
git clone https://github.com/kennymckormick/pyskl.git
cd pyskl
# Please first install pytorch according to instructions on the official website: https://pytorch.org/get-started/locally/. Please use pytorch with version smaller than 1.11.0 and larger (or equal) than 1.5.0
pip install -r requirements.txt
pip install -e .
Once the environment is configured, it's best to run the provided demo:
python demo/demo_skeleton.py demo/ntu_sample.avi demo/demo.mp4
If it outputs the result video, you're all set.
2. Data Preparation

The repository provides a format document and data-generation code for training on your own data. From the code you can tell that you first need to prepare the videos and labels for each action. I organized mine in the following format:
pyskl_data:
train_label_list.txt
val_label_list.txt
----train
      A1.mp4
      B1.mp4
      ......
----val
      A2.mp4
      B2.mp4
      ......
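Since each video is named "name + label" as explained below, the two list files can be generated automatically. Here is a minimal sketch of that step (the single-character label at the end of each base name and the .mp4-only filter are my assumptions):

# Build train/val label lists from the "name + label" file naming scheme,
# where the last character of the base name is the class index.
import os
import os.path as osp

def write_list(video_dir, list_path):
    with open(list_path, 'w') as f:
        for name in sorted(os.listdir(video_dir)):
            if not name.endswith('.mp4'):
                continue
            label = osp.splitext(name)[0][-1]  # last character = class index
            f.write(f'{osp.join(video_dir, name)} {label}\n')

write_list('pyskl_data/train', 'pyskl_data/train_label_list.txt')
write_list('pyskl_data/val', 'pyskl_data/val_label_list.txt')

Each output line then has the form "xxx.mp4 label", which is one of the two formats the annotation script below accepts.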
I first renamed every video to the pattern "name + label", so that when generating the list.txt files, reading the last character of the base name is enough to know its class (the sketch above follows this convention). With train_label_list.txt and val_label_list.txt generated, the data-generation code can be run. Before running it, I modified the source code: I downloaded the required models and config files locally, and I commented out the distributed launch, otherwise it throws errors (I'll paste the errors further below). Of course, if you do want to run distributed, first make sure your server environment is properly configured, or you'll waste a lot of time. My modified code is as follows:
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import pdb

import decord
import mmcv
import numpy as np
import torch.distributed as dist
from mmcv.runner import get_dist_info, init_dist
from tqdm import tqdm

from pyskl.smp import mrlines

try:
    import mmdet
    from mmdet.apis import inference_detector, init_detector
except (ImportError, ModuleNotFoundError):
    raise ImportError('Failed to import `inference_detector` and '
                      '`init_detector` from `mmdet.apis`. These apis are '
                      'required in this script!')

try:
    import mmpose
    from mmpose.apis import inference_top_down_pose_model, init_pose_model
except (ImportError, ModuleNotFoundError):
    raise ImportError('Failed to import `inference_top_down_pose_model` and '
                      '`init_pose_model` from `mmpose.apis`. These apis are '
                      'required in this script!')

default_mmdet_root = osp.dirname(mmdet.__path__[0])
default_mmpose_root = osp.dirname(mmpose.__path__[0])


def extract_frame(video_path):
    # Decode the whole video into a list of frames
    vid = decord.VideoReader(video_path)
    return [x.asnumpy() for x in vid]


def detection_inference(model, frames):
    # Run the person detector on every frame
    results = []
    for frame in frames:
        result = inference_detector(model, frame)
        results.append(result)
    return results


def pose_inference(model, frames, det_results):
    # Run top-down pose estimation on the detected boxes of every frame
    assert len(frames) == len(det_results)
    total_frames = len(frames)
    num_person = max([len(x) for x in det_results])
    kp = np.zeros((num_person, total_frames, 17, 3), dtype=np.float32)

    for i, (f, d) in enumerate(zip(frames, det_results)):
        # Align input format
        d = [dict(bbox=x) for x in list(d)]
        pose = inference_top_down_pose_model(model, f, d, format='xyxy')[0]
        for j, item in enumerate(pose):
            kp[j, i] = item['keypoints']
    return kp


def parse_args():
    parser = argparse.ArgumentParser(
        description='Generate 2D pose annotations for a custom video dataset')
    # * Both mmdet and mmpose should be installed from source
    parser.add_argument('--mmdet-root', type=str, default=default_mmdet_root)
    parser.add_argument('--mmpose-root', type=str, default=default_mmpose_root)
    # * Configs and checkpoints were downloaded locally beforehand
    parser.add_argument('--det-config', type=str, default='demo/faster_rcnn_r50_fpn_2x_coco.py')
    parser.add_argument('--det-ckpt', type=str, default='weights/faster_rcnn_r50_fpn_2x_coco_bbox_mAP-0.384_20200504_210434-a5d8aa15.pth')
    parser.add_argument('--pose-config', type=str, default='demo/hrnet_w32_coco_256x192.py')
    parser.add_argument('--pose-ckpt', type=str, default='weights/hrnet_w32_coco_256x192-c78dce93_20200708.pth')
    # * Only det boxes with score larger than det_score_thr will be kept
    parser.add_argument('--det-score-thr', type=float, default=0.7)
    # * Only det boxes with large enough sizes will be kept
    parser.add_argument('--det-area-thr', type=float, default=1600)
    # * Accepted formats for each line in video_list are:
    # * 1. "xxx.mp4" ('label' is missing, the dataset can be used for inference, but not training)
    # * 2. "xxx.mp4 label" ('label' is an integer (category index),
    # * the result can be used for both training & testing)
    # * All lines should take the same format.
    parser.add_argument('--video-list', type=str, help='the list of source videos')
    # * out should end with '.pkl'
    parser.add_argument('--out', type=str, help='output pickle name')
    parser.add_argument('--tmpdir', type=str, default='tmp')
    parser.add_argument('--local_rank', type=int, default=1)
    args = parser.parse_args()
    # pdb.set_trace()
    # if 'RANK' not in os.environ:
    #     os.environ['RANK'] = str(args.local_rank)
    #     os.environ['WORLD_SIZE'] = str(1)
    # os.environ['MASTER_ADDR'] = 'localhost'
    # os.environ['MASTER_PORT'] = '12345'
    return args


def main():
    args = parse_args()
    assert args.out.endswith('.pkl')

    lines = mrlines(args.video_list)
    lines = [x.split() for x in lines]

    # * We set 'frame_dir' as the base name (w/o. suffix) of each video
    assert len(lines[0]) in [1, 2]
    if len(lines[0]) == 1:
        annos = [dict(frame_dir=osp.basename(x[0]).split('.')[0], filename=x[0]) for x in lines]
    else:
        annos = [dict(frame_dir=osp.basename(x[0]).split('.')[0], filename=x[0], label=int(x[1])) for x in lines]

    rank = 0        # added: run as a single process
    world_size = 1  # added
    # init_dist('pytorch', backend='nccl')
    # rank, world_size = get_dist_info()
    #
    # if rank == 0:
    #     os.makedirs(args.tmpdir, exist_ok=True)
    # dist.barrier()
    os.makedirs(args.tmpdir, exist_ok=True)  # added: the rank-0 guard above is commented out
    my_part = annos
    # my_part = annos[rank::world_size]

    print('loading det_model')
    det_model = init_detector(args.det_config, args.det_ckpt, 'cuda')
    assert det_model.CLASSES[0] == 'person', 'A detector trained on COCO is required'
    print('loading pose_model')
    pose_model = init_pose_model(args.pose_config, args.pose_ckpt, 'cuda')

    for anno in tqdm(my_part):
        frames = extract_frame(anno['filename'])
        print("processing:", anno['filename'])
        det_results = detection_inference(det_model, frames)
        # * Get detection results for human
        det_results = [x[0] for x in det_results]
        for i, res in enumerate(det_results):
            # * filter boxes with small scores
            res = res[res[:, 4] >= args.det_score_thr]
            # * filter boxes with small areas
            box_areas = (res[:, 3] - res[:, 1]) * (res[:, 2] - res[:, 0])
            assert np.all(box_areas >= 0)
            res = res[box_areas >= args.det_area_thr]
            det_results[i] = res

        pose_results = pose_inference(pose_model, frames, det_results)
        shape = frames[0].shape[:2]
        anno['img_shape'] = anno['original_shape'] = shape
        anno['total_frames'] = len(frames)
        anno['num_person_raw'] = pose_results.shape[0]
        anno['keypoint'] = pose_results[..., :2].astype(np.float16)
        anno['keypoint_score'] = pose_results[..., 2].astype(np.float16)
        anno.pop('filename')

    mmcv.dump(my_part, osp.join(args.tmpdir, f'part_{rank}.pkl'))
    # dist.barrier()
    if rank == 0:
        parts = [mmcv.load(osp.join(args.tmpdir, f'part_{i}.pkl')) for i in range(world_size)]
        rem = len(annos) % world_size
        if rem:
            for i in range(rem, world_size):
                parts[i].append(None)
        ordered_results = []
        for res in zip(*parts):
            ordered_results.extend(list(res))
        ordered_results = ordered_results[:len(annos)]
        mmcv.dump(ordered_results, args.out)


if __name__ == '__main__':
    main()
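After running the script on a list file (something like "python custom_2d_skeleton.py --video-list pyskl_data/train_label_list.txt --out train.pkl" — the script name and paths here are my own placeholders), you can sanity-check the output pickle. A minimal sketch based on the fields the script writes:

# Inspect the generated annotation pickle; 'train.pkl' is whatever path
# was passed via --out (a hypothetical name).
import mmcv

annos = mmcv.load('train.pkl')
print(len(annos))                      # one entry per video
sample = annos[0]
print(sample['frame_dir'], sample.get('label'))
print(sample['img_shape'], sample['total_frames'])
print(sample['keypoint'].shape)        # (num_person, total_frames, 17, 2), float16
print(sample['keypoint_score'].shape)  # (num_person, total_frames, 17), float16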
To recap the changes I made: first, I commented out the following code:
# if 'RANK' not in os.environ:
#     os.environ['RANK'] = str(args.local_rank)
#     os.environ['WORLD_SIZE'] = str(1)
# os.environ['MASTER_ADDR'] = 'localhost'
# os.environ['MASTER_PORT'] = '12345'
Problem 1: KeyError: 'RANK'

Tracing the problem, I found that the code later reads "RANK", while the environment assignment sets "LOCAL_RANK", and my system environment has no "RANK" either, so I changed "LOCAL_RANK" to "RANK" (see the sketch after Problem 2).
Problem 2: environment variable WORLD_SIZE

Checking the system environment, "WORLD_SIZE" was missing as well, so I assigned it a value too. Later, running it reported missing "MASTER_ADDR" and "MASTER_PORT", so I assigned both of those as well; the assignments are exactly the commented-out lines shown above. With that, it no longer errored out, but it hung during model construction.
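Putting Problems 1 and 2 together, the workaround amounts to setting the distributed environment variables explicitly before init_dist() runs. A minimal sketch, assuming a single-machine, single-process run (the concrete values are assumptions):

# Set the distributed env vars that init_dist('pytorch') expects.
# Values below assume one process on one machine.
import os

os.environ.setdefault('RANK', '0')
os.environ.setdefault('WORLD_SIZE', '1')
os.environ.setdefault('MASTER_ADDR', 'localhost')
os.environ.setdefault('MASTER_PORT', '12345')

Even with these set, model construction hung for me, which is why the modified script above skips init_dist entirely and hard-codes rank = 0 and world_size = 1.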