11# Copyright (c) OpenMMLab. All rights reserved.
2+ from typing import Tuple
3+
24import torch
35from mmcv .cnn import ConvModule
46from torch import nn as nn
57
68from mmdet3d .models .layers .pointnet_modules import build_sa_module
79from mmdet3d .registry import MODELS
10+ from mmdet3d .utils import OptConfigType
811from .base_pointnet import BasePointNet
912
13+ ThreeTupleIntType = Tuple [Tuple [Tuple [int , int , int ]]]
14+ TwoTupleIntType = Tuple [Tuple [int , int , int ]]
15+ TwoTupleStrType = Tuple [Tuple [str ]]
16+
1017
1118@MODELS .register_module ()
1219class PointNet2SAMSG (BasePointNet ):
@@ -22,7 +29,7 @@ class PointNet2SAMSG(BasePointNet):
2229 sa_channels (tuple[tuple[int]]): Out channels of each mlp in SA module.
2330 aggregation_channels (tuple[int]): Out channels of aggregation
2431 multi-scale grouping features.
25- fps_mods (tuple[int]) : Mod of FPS for each SA module.
32+ fps_mods Sequence[Tuple[str]] : Mod of FPS for each SA module.
2633 fps_sample_range_lists (tuple[tuple[int]]): The number of sampling
2734 points which each SA module samples.
2835 dilated_group (tuple[bool]): Whether to use dilated ball query for
@@ -38,26 +45,37 @@ class PointNet2SAMSG(BasePointNet):
3845 """
3946
4047 def __init__ (self ,
41- in_channels ,
42- num_points = (2048 , 1024 , 512 , 256 ),
43- radii = ((0.2 , 0.4 , 0.8 ), (0.4 , 0.8 , 1.6 ), (1.6 , 3.2 , 4.8 )),
44- num_samples = ((32 , 32 , 64 ), (32 , 32 , 64 ), (32 , 32 , 32 )),
45- sa_channels = (((16 , 16 , 32 ), (16 , 16 , 32 ), (32 , 32 , 64 )),
46- ((64 , 64 , 128 ), (64 , 64 , 128 ), (64 , 96 , 128 )),
47- ((128 , 128 , 256 ), (128 , 192 , 256 ), (128 , 256 ,
48- 256 ))),
49- aggregation_channels = (64 , 128 , 256 ),
50- fps_mods = (('D-FPS' ), ('FS' ), ('F-FPS' , 'D-FPS' )),
51- fps_sample_range_lists = ((- 1 ), (- 1 ), (512 , - 1 )),
52- dilated_group = (True , True , True ),
53- out_indices = (2 , ),
54- norm_cfg = dict (type = 'BN2d' ),
55- sa_cfg = dict (
48+ in_channels : int ,
49+ num_points : Tuple [int ] = (2048 , 1024 , 512 , 256 ),
50+ radii : Tuple [Tuple [float , float , float ]] = (
51+ (0.2 , 0.4 , 0.8 ),
52+ (0.4 , 0.8 , 1.6 ),
53+ (1.6 , 3.2 , 4.8 ),
54+ ),
55+ num_samples : TwoTupleIntType = ((32 , 32 , 64 ), (32 , 32 , 64 ),
56+ (32 , 32 , 32 )),
57+ sa_channels : ThreeTupleIntType = (((16 , 16 , 32 ), (16 , 16 , 32 ),
58+ (32 , 32 , 64 )),
59+ ((64 , 64 , 128 ),
60+ (64 , 64 , 128 ), (64 , 96 ,
61+ 128 )),
62+ ((128 , 128 , 256 ),
63+ (128 , 192 , 256 ), (128 , 256 ,
64+ 256 ))),
65+ aggregation_channels : Tuple [int ] = (64 , 128 , 256 ),
66+ fps_mods : TwoTupleStrType = (('D-FPS' ), ('FS' ), ('F-FPS' ,
67+ 'D-FPS' )),
68+ fps_sample_range_lists : TwoTupleIntType = ((- 1 ), (- 1 ), (512 ,
69+ - 1 )),
70+ dilated_group : Tuple [bool ] = (True , True , True ),
71+ out_indices : Tuple [int ] = (2 , ),
72+ norm_cfg : dict = dict (type = 'BN2d' ),
73+ sa_cfg : dict = dict (
5674 type = 'PointSAModuleMSG' ,
5775 pool_mod = 'max' ,
5876 use_xyz = True ,
5977 normalize_xyz = False ),
60- init_cfg = None ):
78+ init_cfg : OptConfigType = None ):
6179 super ().__init__ (init_cfg = init_cfg )
6280 self .num_sa = len (sa_channels )
6381 self .out_indices = out_indices
@@ -123,7 +141,7 @@ def __init__(self,
123141 bias = True ))
124142 sa_in_channel = cur_aggregation_channel
125143
126- def forward (self , points ):
144+ def forward (self , points : torch . Tensor ):
127145 """Forward pass.
128146
129147 Args:
0 commit comments