-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
support SECOND-IoU and baseline results (#517)
* support SECOND-IoU * rename file * modify readme and fix model name in __init__.py
- Loading branch information
Showing
6 changed files
with
532 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
import torch | ||
from .detector3d_template import Detector3DTemplate | ||
from ..model_utils.model_nms_utils import class_agnostic_nms | ||
from ...ops.roiaware_pool3d import roiaware_pool3d_utils | ||
|
||
|
||
class SECONDNetIoU(Detector3DTemplate): | ||
def __init__(self, model_cfg, num_class, dataset): | ||
super().__init__(model_cfg=model_cfg, num_class=num_class, dataset=dataset) | ||
self.module_list = self.build_networks() | ||
|
||
def forward(self, batch_dict): | ||
batch_dict['dataset_cfg'] = self.dataset.dataset_cfg | ||
for cur_module in self.module_list: | ||
batch_dict = cur_module(batch_dict) | ||
|
||
if self.training: | ||
loss, tb_dict, disp_dict = self.get_training_loss() | ||
|
||
ret_dict = { | ||
'loss': loss | ||
} | ||
return ret_dict, tb_dict, disp_dict | ||
else: | ||
pred_dicts, recall_dicts = self.post_processing(batch_dict) | ||
return pred_dicts, recall_dicts | ||
|
||
def get_training_loss(self): | ||
disp_dict = {} | ||
|
||
loss_rpn, tb_dict = self.dense_head.get_loss() | ||
loss_rcnn, tb_dict = self.roi_head.get_loss(tb_dict) | ||
|
||
loss = loss_rpn + loss_rcnn | ||
return loss, tb_dict, disp_dict | ||
|
||
@staticmethod | ||
def cal_scores_by_npoints(cls_scores, iou_scores, num_points_in_gt, cls_thresh=10, iou_thresh=100): | ||
""" | ||
Args: | ||
cls_scores: (N) | ||
iou_scores: (N) | ||
num_points_in_gt: (N, 7+c) | ||
cls_thresh: scalar | ||
iou_thresh: scalar | ||
""" | ||
assert iou_thresh >= cls_thresh | ||
alpha = torch.zeros(cls_scores.shape, dtype=torch.float32).cuda() | ||
alpha[num_points_in_gt <= cls_thresh] = 0 | ||
alpha[num_points_in_gt >= iou_thresh] = 1 | ||
|
||
mask = ((num_points_in_gt > cls_thresh) & (num_points_in_gt < iou_thresh)) | ||
alpha[mask] = (num_points_in_gt[mask] - 10) / (iou_thresh - cls_thresh) | ||
|
||
scores = (1 - alpha) * cls_scores + alpha * iou_scores | ||
|
||
return scores | ||
|
||
def set_nms_score_by_class(self, iou_preds, cls_preds, label_preds, score_by_class): | ||
n_classes = torch.unique(label_preds).shape[0] | ||
nms_scores = torch.zeros(iou_preds.shape, dtype=torch.float32).cuda() | ||
for i in range(n_classes): | ||
mask = label_preds == (i + 1) | ||
class_name = self.class_names[i] | ||
score_type = score_by_class[class_name] | ||
if score_type == 'iou': | ||
nms_scores[mask] = iou_preds[mask] | ||
elif score_type == 'cls': | ||
nms_scores[mask] = cls_preds[mask] | ||
else: | ||
raise NotImplementedError | ||
|
||
return nms_scores | ||
|
||
def post_processing(self, batch_dict): | ||
""" | ||
Args: | ||
batch_dict: | ||
batch_size: | ||
batch_cls_preds: (B, num_boxes, num_classes | 1) or (N1+N2+..., num_classes | 1) | ||
batch_box_preds: (B, num_boxes, 7+C) or (N1+N2+..., 7+C) | ||
cls_preds_normalized: indicate whether batch_cls_preds is normalized | ||
batch_index: optional (N1+N2+...) | ||
roi_labels: (B, num_rois) 1 .. num_classes | ||
Returns: | ||
""" | ||
post_process_cfg = self.model_cfg.POST_PROCESSING | ||
batch_size = batch_dict['batch_size'] | ||
recall_dict = {} | ||
pred_dicts = [] | ||
for index in range(batch_size): | ||
if batch_dict.get('batch_index', None) is not None: | ||
assert batch_dict['batch_cls_preds'].shape.__len__() == 2 | ||
batch_mask = (batch_dict['batch_index'] == index) | ||
else: | ||
assert batch_dict['batch_cls_preds'].shape.__len__() == 3 | ||
batch_mask = index | ||
|
||
box_preds = batch_dict['batch_box_preds'][batch_mask] | ||
iou_preds = batch_dict['batch_cls_preds'][batch_mask] | ||
cls_preds = batch_dict['roi_scores'][batch_mask] | ||
|
||
src_iou_preds = iou_preds | ||
src_box_preds = box_preds | ||
src_cls_preds = cls_preds | ||
assert iou_preds.shape[1] in [1, self.num_class] | ||
|
||
if not batch_dict['cls_preds_normalized']: | ||
iou_preds = torch.sigmoid(iou_preds) | ||
cls_preds = torch.sigmoid(cls_preds) | ||
|
||
if post_process_cfg.NMS_CONFIG.MULTI_CLASSES_NMS: | ||
raise NotImplementedError | ||
else: | ||
iou_preds, label_preds = torch.max(iou_preds, dim=-1) | ||
label_preds = batch_dict['roi_labels'][index] if batch_dict.get('has_class_labels', False) else label_preds + 1 | ||
|
||
if post_process_cfg.NMS_CONFIG.get('SCORE_BY_CLASS', None) and \ | ||
post_process_cfg.NMS_CONFIG.SCORE_TYPE == 'score_by_class': | ||
nms_scores = self.set_nms_score_by_class( | ||
iou_preds, cls_preds, label_preds, post_process_cfg.NMS_CONFIG.SCORE_BY_CLASS | ||
) | ||
elif post_process_cfg.NMS_CONFIG.get('SCORE_TYPE', None) == 'iou' or \ | ||
post_process_cfg.NMS_CONFIG.get('SCORE_TYPE', None) is None: | ||
nms_scores = iou_preds | ||
elif post_process_cfg.NMS_CONFIG.SCORE_TYPE == 'cls': | ||
nms_scores = cls_preds | ||
elif post_process_cfg.NMS_CONFIG.SCORE_TYPE == 'weighted_iou_cls': | ||
nms_scores = post_process_cfg.NMS_CONFIG.SCORE_WEIGHTS.iou * iou_preds + \ | ||
post_process_cfg.NMS_CONFIG.SCORE_WEIGHTS.cls * cls_preds | ||
elif post_process_cfg.NMS_CONFIG.SCORE_TYPE == 'num_pts_iou_cls': | ||
point_mask = (batch_dict['points'][:, 0] == batch_mask) | ||
batch_points = batch_dict['points'][point_mask][:, 1:4] | ||
|
||
num_pts_in_gt = roiaware_pool3d_utils.points_in_boxes_cpu( | ||
batch_points.cpu(), box_preds[:, 0:7].cpu() | ||
).sum(dim=1).float().cuda() | ||
|
||
score_thresh_cfg = post_process_cfg.NMS_CONFIG.SCORE_THRESH | ||
nms_scores = self.cal_scores_by_npoints( | ||
cls_preds, iou_preds, num_pts_in_gt, | ||
score_thresh_cfg.cls, score_thresh_cfg.iou | ||
) | ||
else: | ||
raise NotImplementedError | ||
|
||
selected, selected_scores = class_agnostic_nms( | ||
box_scores=nms_scores, box_preds=box_preds, | ||
nms_config=post_process_cfg.NMS_CONFIG, | ||
score_thresh=post_process_cfg.SCORE_THRESH | ||
) | ||
|
||
if post_process_cfg.OUTPUT_RAW_SCORE: | ||
raise NotImplementedError | ||
|
||
final_scores = selected_scores | ||
final_labels = label_preds[selected] | ||
final_boxes = box_preds[selected] | ||
|
||
recall_dict = self.generate_recall_record( | ||
box_preds=final_boxes if 'rois' not in batch_dict else src_box_preds, | ||
recall_dict=recall_dict, batch_index=index, data_dict=batch_dict, | ||
thresh_list=post_process_cfg.RECALL_THRESH_LIST | ||
) | ||
|
||
record_dict = { | ||
'pred_boxes': final_boxes, | ||
'pred_scores': final_scores, | ||
'pred_labels': final_labels, | ||
'pred_cls_scores': cls_preds[selected], | ||
'pred_iou_scores': iou_preds[selected] | ||
} | ||
|
||
pred_dicts.append(record_dict) | ||
|
||
return pred_dicts, recall_dict |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,13 @@ | ||
from .partA2_head import PartA2FCHead | ||
from .pointrcnn_head import PointRCNNHead | ||
from .pvrcnn_head import PVRCNNHead | ||
from .second_head import SECONDHead | ||
from .roi_head_template import RoIHeadTemplate | ||
|
||
__all__ = { | ||
'RoIHeadTemplate': RoIHeadTemplate, | ||
'PartA2FCHead': PartA2FCHead, | ||
'PVRCNNHead': PVRCNNHead, | ||
'SECONDHead': SECONDHead, | ||
'PointRCNNHead': PointRCNNHead | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
import torch | ||
import torch.nn as nn | ||
from .roi_head_template import RoIHeadTemplate | ||
from ...utils import common_utils, loss_utils | ||
|
||
|
||
class SECONDHead(RoIHeadTemplate): | ||
def __init__(self, input_channels, model_cfg, num_class=1): | ||
super().__init__(num_class=num_class, model_cfg=model_cfg) | ||
self.model_cfg = model_cfg | ||
|
||
GRID_SIZE = self.model_cfg.ROI_GRID_POOL.GRID_SIZE | ||
pre_channel = self.model_cfg.ROI_GRID_POOL.IN_CHANNEL * GRID_SIZE * GRID_SIZE | ||
|
||
shared_fc_list = [] | ||
for k in range(0, self.model_cfg.SHARED_FC.__len__()): | ||
shared_fc_list.extend([ | ||
nn.Conv1d(pre_channel, self.model_cfg.SHARED_FC[k], kernel_size=1, bias=False), | ||
nn.BatchNorm1d(self.model_cfg.SHARED_FC[k]), | ||
nn.ReLU() | ||
]) | ||
pre_channel = self.model_cfg.SHARED_FC[k] | ||
|
||
if k != self.model_cfg.SHARED_FC.__len__() - 1 and self.model_cfg.DP_RATIO > 0: | ||
shared_fc_list.append(nn.Dropout(self.model_cfg.DP_RATIO)) | ||
|
||
self.shared_fc_layer = nn.Sequential(*shared_fc_list) | ||
|
||
self.iou_layers = self.make_fc_layers( | ||
input_channels=pre_channel, output_channels=1, fc_list=self.model_cfg.IOU_FC | ||
) | ||
self.init_weights(weight_init='xavier') | ||
|
||
def init_weights(self, weight_init='xavier'): | ||
if weight_init == 'kaiming': | ||
init_func = nn.init.kaiming_normal_ | ||
elif weight_init == 'xavier': | ||
init_func = nn.init.xavier_normal_ | ||
elif weight_init == 'normal': | ||
init_func = nn.init.normal_ | ||
else: | ||
raise NotImplementedError | ||
|
||
for m in self.modules(): | ||
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d): | ||
if weight_init == 'normal': | ||
init_func(m.weight, mean=0, std=0.001) | ||
else: | ||
init_func(m.weight) | ||
if m.bias is not None: | ||
nn.init.constant_(m.bias, 0) | ||
|
||
def roi_grid_pool(self, batch_dict): | ||
""" | ||
Args: | ||
batch_dict: | ||
batch_size: | ||
rois: (B, num_rois, 7 + C) | ||
spatial_features_2d: (B, C, H, W) | ||
Returns: | ||
""" | ||
batch_size = batch_dict['batch_size'] | ||
rois = batch_dict['rois'].detach() | ||
spatial_features_2d = batch_dict['spatial_features_2d'].detach() | ||
height, width = spatial_features_2d.size(2), spatial_features_2d.size(3) | ||
|
||
dataset_cfg = batch_dict['dataset_cfg'] | ||
min_x = dataset_cfg.POINT_CLOUD_RANGE[0] | ||
min_y = dataset_cfg.POINT_CLOUD_RANGE[1] | ||
voxel_size_x = dataset_cfg.DATA_PROCESSOR[-1].VOXEL_SIZE[0] | ||
voxel_size_y = dataset_cfg.DATA_PROCESSOR[-1].VOXEL_SIZE[1] | ||
down_sample_ratio = self.model_cfg.ROI_GRID_POOL.DOWNSAMPLE_RATIO | ||
|
||
pooled_features_list = [] | ||
torch.backends.cudnn.enabled = False | ||
for b_id in range(batch_size): | ||
# Map global boxes coordinates to feature map coordinates | ||
x1 = (rois[b_id, :, 0] - rois[b_id, :, 3] / 2 - min_x) / (voxel_size_x * down_sample_ratio) | ||
x2 = (rois[b_id, :, 0] + rois[b_id, :, 3] / 2 - min_x) / (voxel_size_x * down_sample_ratio) | ||
y1 = (rois[b_id, :, 1] - rois[b_id, :, 4] / 2 - min_y) / (voxel_size_y * down_sample_ratio) | ||
y2 = (rois[b_id, :, 1] + rois[b_id, :, 4] / 2 - min_y) / (voxel_size_y * down_sample_ratio) | ||
|
||
angle, _ = common_utils.check_numpy_to_torch(rois[b_id, :, 6]) | ||
|
||
cosa = torch.cos(angle) | ||
sina = torch.sin(angle) | ||
|
||
theta = torch.stack(( | ||
(x2 - x1) / (width - 1) * cosa, (x2 - x1) / (width - 1) * (-sina), (x1 + x2 - width + 1) / (width - 1), | ||
(y2 - y1) / (height - 1) * sina, (y2 - y1) / (height - 1) * cosa, (y1 + y2 - height + 1) / (height - 1) | ||
), dim=1).view(-1, 2, 3).float() | ||
|
||
grid_size = self.model_cfg.ROI_GRID_POOL.GRID_SIZE | ||
grid = nn.functional.affine_grid( | ||
theta, | ||
torch.Size((rois.size(1), spatial_features_2d.size(1), grid_size, grid_size)) | ||
) | ||
|
||
pooled_features = nn.functional.grid_sample( | ||
spatial_features_2d[b_id].unsqueeze(0).expand(rois.size(1), spatial_features_2d.size(1), height, width), | ||
grid | ||
) | ||
|
||
pooled_features_list.append(pooled_features) | ||
|
||
torch.backends.cudnn.enabled = True | ||
pooled_features = torch.cat(pooled_features_list, dim=0) | ||
|
||
return pooled_features | ||
|
||
def forward(self, batch_dict): | ||
""" | ||
:param input_data: input dict | ||
:return: | ||
""" | ||
targets_dict = self.proposal_layer( | ||
batch_dict, nms_config=self.model_cfg.NMS_CONFIG['TRAIN' if self.training else 'TEST'] | ||
) | ||
if self.training: | ||
targets_dict = self.assign_targets(batch_dict) | ||
batch_dict['rois'] = targets_dict['rois'] | ||
batch_dict['roi_labels'] = targets_dict['roi_labels'] | ||
|
||
# RoI aware pooling | ||
pooled_features = self.roi_grid_pool(batch_dict) # (BxN, C, 7, 7) | ||
batch_size_rcnn = pooled_features.shape[0] | ||
|
||
shared_features = self.shared_fc_layer(pooled_features.view(batch_size_rcnn, -1, 1)) | ||
rcnn_iou = self.iou_layers(shared_features).transpose(1, 2).contiguous().squeeze(dim=1) # (B*N, 1) | ||
|
||
if not self.training: | ||
batch_dict['batch_cls_preds'] = rcnn_iou.view(batch_dict['batch_size'], -1, rcnn_iou.shape[-1]) | ||
batch_dict['batch_box_preds'] = batch_dict['rois'] | ||
batch_dict['cls_preds_normalized'] = False | ||
else: | ||
targets_dict['rcnn_iou'] = rcnn_iou | ||
|
||
self.forward_ret_dict = targets_dict | ||
|
||
return batch_dict | ||
|
||
def get_loss(self, tb_dict=None): | ||
tb_dict = {} if tb_dict is None else tb_dict | ||
rcnn_loss = 0 | ||
rcnn_loss_cls, cls_tb_dict = self.get_box_iou_layer_loss(self.forward_ret_dict) | ||
rcnn_loss += rcnn_loss_cls | ||
tb_dict.update(cls_tb_dict) | ||
|
||
tb_dict['rcnn_loss'] = rcnn_loss.item() | ||
return rcnn_loss, tb_dict | ||
|
||
def get_box_iou_layer_loss(self, forward_ret_dict): | ||
loss_cfgs = self.model_cfg.LOSS_CONFIG | ||
rcnn_iou = forward_ret_dict['rcnn_iou'] | ||
rcnn_iou_labels = forward_ret_dict['rcnn_cls_labels'].view(-1) | ||
rcnn_iou_flat = rcnn_iou.view(-1) | ||
if loss_cfgs.IOU_LOSS == 'BinaryCrossEntropy': | ||
batch_loss_iou = nn.functional.binary_cross_entropy_with_logits( | ||
rcnn_iou_flat, | ||
rcnn_iou_labels.float(), reduction='none' | ||
) | ||
elif loss_cfgs.IOU_LOSS == 'L2': | ||
batch_loss_iou = nn.functional.mse_loss(rcnn_iou_flat, rcnn_iou_labels, reduction='none') | ||
elif loss_cfgs.IOU_LOSS == 'smoothL1': | ||
diff = rcnn_iou_flat - rcnn_iou_labels | ||
batch_loss_iou = loss_utils.WeightedSmoothL1Loss.smooth_l1_loss(diff, 1.0 / 9.0) | ||
elif loss_cfgs.IOU_LOSS == 'focalbce': | ||
batch_loss_iou = loss_utils.sigmoid_focal_cls_loss(rcnn_iou_flat, rcnn_iou_labels) | ||
else: | ||
raise NotImplementedError | ||
|
||
iou_valid_mask = (rcnn_iou_labels >= 0).float() | ||
rcnn_loss_iou = (batch_loss_iou * iou_valid_mask).sum() / torch.clamp(iou_valid_mask.sum(), min=1.0) | ||
|
||
rcnn_loss_iou = rcnn_loss_iou * loss_cfgs.LOSS_WEIGHTS['rcnn_iou_weight'] | ||
tb_dict = {'rcnn_loss_iou': rcnn_loss_iou.item()} | ||
return rcnn_loss_iou, tb_dict |
Oops, something went wrong.
e3bec15
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The provided model for SECOND_IOU cannot be reproduced. Any other information about SCORE_TYPE?
e3bec15
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will check it later.
e3bec15
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I reproduce the result successfully. Can you reproduce the result of second?
2021-05-25 16:48:09,630 INFO *************** EPOCH 80 EVALUATION *****************
eval: 100%|██████████| 118/118 [00:30<00:00, 3.81it/s, recall_0.3=(2136, 2136) / 2210]
2021-05-25 16:48:41,134 INFO *************** Performance of EPOCH 80 *****************
2021-05-25 16:48:41,137 INFO Generate label finished(sec_per_example: 0.0084 second).
2021-05-25 16:48:41,137 INFO recall_roi_0.3: 0.971216
2021-05-25 16:48:41,137 INFO recall_rcnn_0.3: 0.971216
2021-05-25 16:48:41,137 INFO recall_roi_0.5: 0.933273
2021-05-25 16:48:41,137 INFO recall_rcnn_0.5: 0.933273
2021-05-25 16:48:41,138 INFO recall_roi_0.7: 0.712612
2021-05-25 16:48:41,138 INFO recall_rcnn_0.7: 0.712612
2021-05-25 16:48:41,141 INFO Average predicted number of objects(3769 samples): 13.678
2021-05-25 16:49:09,050 INFO Car AP@0.70, 0.70, 0.70:
bbox AP:90.8298, 89.8464, 89.2890
bev AP:90.2490, 88.0956, 87.1298
3d AP:89.2254, 79.1751, 78.2560
aos AP:90.81, 89.73, 89.10
Car AP_R40@0.70, 0.70, 0.70:
bbox AP:95.9564, 94.4257, 92.0571
bev AP:92.8637, 89.9651, 88.1752
3d AP:91.1837, 82.4682, 79.8206
aos AP:95.92, 94.28, 91.84
Car AP@0.70, 0.50, 0.50:
bbox AP:90.8298, 89.8464, 89.2890
bev AP:90.8229, 89.9098, 89.4771
3d AP:90.8229, 89.8907, 89.4380
aos AP:90.81, 89.73, 89.10
Car AP_R40@0.70, 0.50, 0.50:
bbox AP:95.9564, 94.4257, 92.0571
bev AP:95.9862, 94.7134, 94.3089
3d AP:95.9769, 94.6662, 94.1957
aos AP:95.92, 94.28, 91.84
Pedestrian AP@0.50, 0.50, 0.50:
bbox AP:72.6088, 67.1919, 64.0639
bev AP:67.5701, 60.5777, 56.7529
3d AP:62.7718, 56.6545, 51.3505
aos AP:68.24, 62.80, 59.55
Pedestrian AP_R40@0.50, 0.50, 0.50:
bbox AP:73.3103, 67.7723, 64.1540
bev AP:67.7177, 60.3325, 55.6805
3d AP:63.2151, 55.7024, 50.3015
aos AP:68.40, 62.82, 59.10
Pedestrian AP@0.50, 0.25, 0.25:
bbox AP:72.6088, 67.1919, 64.0639
bev AP:79.2449, 74.9767, 71.1995
3d AP:78.8350, 74.7658, 70.7464
aos AP:68.24, 62.80, 59.55
Pedestrian AP_R40@0.50, 0.25, 0.25:
bbox AP:73.3103, 67.7723, 64.1540
bev AP:80.5269, 76.2809, 72.2288
3d AP:80.2924, 76.0266, 71.6252
aos AP:68.40, 62.82, 59.10
Cyclist AP@0.50, 0.50, 0.50:
bbox AP:92.6116, 82.0762, 78.2363
bev AP:91.3906, 75.8321, 71.1691
3d AP:86.9092, 70.9220, 66.8777
aos AP:92.29, 81.11, 77.37
Cyclist AP_R40@0.50, 0.50, 0.50:
bbox AP:95.1193, 84.1542, 79.6162
bev AP:93.6958, 77.0219, 72.1574
3d AP:90.3740, 71.6101, 67.0025
aos AP:94.78, 83.16, 78.63
Cyclist AP@0.50, 0.25, 0.25:
bbox AP:92.6116, 82.0762, 78.2363
bev AP:91.8523, 79.3298, 75.3813
3d AP:91.8523, 79.3298, 75.3813
aos AP:92.29, 81.11, 77.37
Cyclist AP_R40@0.50, 0.25, 0.25:
bbox AP:95.1193, 84.1542, 79.6162
bev AP:94.4396, 81.0085, 76.5129
3d AP:94.4396, 81.0085, 76.5129
aos AP:94.78, 83.16, 78.63
e3bec15
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was also able to reproduce the numbers.
Shown here are the R40 results (easy, moderate, hard from left to right).