Skip to content
Snippets Groups Projects
Commit da31bb76 authored by Nikos Athanasiou's avatar Nikos Athanasiou
Browse files

disable logging metrics debug cluster

parent da66afc8
No related branches found
No related tags found
No related merge requests found
......@@ -500,35 +500,36 @@ class TEACH(BaseModel):
# # Compute the metrics
# # breakpoint()
# # output_features_1_T_lst[0] -> F, feats -> F, xyz
if split == "val":# or batch_idx == 0:
self.transforms.rots2rfeats.to(latent_vector_0_T.device)
self.transforms.rots2joints.to(latent_vector_0_T.device)
# if split == "val":# or batch_idx == 0:
# self.transforms.rots2rfeats.to(latent_vector_0_T.device)
# self.transforms.rots2joints.to(latent_vector_0_T.device)
# if not self.hparams.losses.loss_on_transition:
# joints_0 = [self.transforms.rots2joints(self.transforms.rots2rfeats.inverse(x.detach())) for x in output_features_0_T_lst]
# joints_1 = [self.transforms.rots2joints(self.transforms.rots2rfeats.inverse(x.detach())) for x in output_features_1_T_lst]
# ref_joints_0 = [self.transforms.rots2joints(self.transforms.rots2rfeats.inverse(x.detach())) for x in input_motion_feats_0_lst]
# ref_joints_1 = [self.transforms.rots2joints(self.transforms.rots2rfeats.inverse(x.detach())) for x in input_motion_feats_1_lst]
# self.metrics_0.update(joints_0.detach(), ref_joints_0.detach(), length_0.detach())
# self.metrics_1.update(joints_1.detach(), ref_joints_1.detach(), length_1.detach())
# else:
# # if self.motiondecoder.hparams.prev_data_mode == "hist_frame_outpast":
# # # remove the hframes
# # # M0 / M1[hframes:]
# # breakpoint()
# # joints_output = [self.transforms.rots2joints(self.transforms.rots2rfeats.inverse(torch.cat((x, y[hframes:]))))
# # for x, y in zip(output_features_0_T_lst, output_features_1_T_with_transition_lst)]
# # else:
# # Evalute on all the joints (concatenation)
# joints_output = [self.transforms.rots2joints(self.transforms.rots2rfeats.inverse(torch.cat((x.detach(), y.detach()))))
# for x, y in zip(output_features_0_T_lst, output_features_1_T_with_transition_lst)]
# joints_input = [self.transforms.rots2joints(self.transforms.rots2rfeats.inverse(torch.cat((x.detach(), y.detach()))))
# for x, y in zip(input_motion_feats_0_lst, input_motion_feats_1_with_transition_lst)]
# self.metrics.update(joints_output, joints_input, total_length)
if not self.hparams.losses.loss_on_transition:
joints_0 = [self.transforms.rots2joints(self.transforms.rots2rfeats.inverse(x.detach())) for x in output_features_0_T_lst]
joints_1 = [self.transforms.rots2joints(self.transforms.rots2rfeats.inverse(x.detach())) for x in output_features_1_T_lst]
ref_joints_0 = [self.transforms.rots2joints(self.transforms.rots2rfeats.inverse(x.detach())) for x in input_motion_feats_0_lst]
ref_joints_1 = [self.transforms.rots2joints(self.transforms.rots2rfeats.inverse(x.detach())) for x in input_motion_feats_1_lst]
self.metrics_0.update(joints_0.detach(), ref_joints_0.detach(), length_0.detach())
self.metrics_1.update(joints_1.detach(), ref_joints_1.detach(), length_1.detach())
else:
# if self.motiondecoder.hparams.prev_data_mode == "hist_frame_outpast":
# # remove the hframes
# # M0 / M1[hframes:]
# breakpoint()
# joints_output = [self.transforms.rots2joints(self.transforms.rots2rfeats.inverse(torch.cat((x, y[hframes:]))))
# for x, y in zip(output_features_0_T_lst, output_features_1_T_with_transition_lst)]
# else:
# Evalute on all the joints (concatenation)
joints_output = [self.transforms.rots2joints(self.transforms.rots2rfeats.inverse(torch.cat((x.detach(), y.detach()))))
for x, y in zip(output_features_0_T_lst, output_features_1_T_with_transition_lst)]
joints_input = [self.transforms.rots2joints(self.transforms.rots2rfeats.inverse(torch.cat((x.detach(), y.detach()))))
for x, y in zip(input_motion_feats_0_lst, input_motion_feats_1_with_transition_lst)]
self.metrics.update(joints_output, joints_input, total_length)
# render_list_train = ['1574-1', '7286-0', '6001-0', '4224-2', '3415-0', '2634-0', '2424-1', '4550-0']
# render_list_val = ['2307-1', '6078-0', '5210-0', '12255-0', '11346-2', '11671-1', '443-8', '3290-3', '2014-0', '973-12']
# breakpoint()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment