Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions scripts/intprim_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import intprim.basis.polynomial_model
import intprim.basis.sigmoidal_model
import intprim.basis.selection
import intprim.util.gaussian
# import intprim.util.gaussian
import intprim_framework_ros.msg
import intprim_framework_ros.srv
import itertools
Expand Down Expand Up @@ -675,7 +675,7 @@ def get_statistics_callback(self, request):
self.statistics_publisher.publish(message)

# Export debugging XML file here.
if(self.bip_parameters[0]["debug"]):
if(self.bip_parameters[request.interaction_id]["debug"]):
self.stat_collector.export(self.bip_instances[request.interaction_id], self.bip_parameters[request.interaction_id]["debug_directory"], request.bag_file, self.bip_parameters[request.interaction_id]["num_samples"])

# Return values as part of service call as well
Expand Down Expand Up @@ -708,11 +708,11 @@ def initialize_state_callback(self, request):
self.initialize_state()

# Initialize stat collection for debugging
if(self.bip_parameters[0]["debug"]):
self.stat_collector = analysis.stat_collector.StatCollector(self.bip_instances[0], self.bip_parameters[0]["generate_indices"], np.setdiff1d(self.bip_parameters[0]["all_active_dofs"], self.bip_parameters[0]["generate_indices"]))
if(self.bip_parameters[self.primary_instance]["debug"]):
self.stat_collector = analysis.stat_collector.StatCollector(self.bip_instances[self.primary_instance], self.bip_parameters[self.primary_instance]["generate_indices"], np.setdiff1d(self.bip_parameters[self.primary_instance]["all_active_dofs"], self.bip_parameters[self.primary_instance]["generate_indices"]))

generated_trajectory = self.bip_instances[0].get_mean_trajectory(num_samples = self.bip_parameters[0]["num_samples"])
self.stat_collector.collect(self.bip_instances[0], np.array([[] for _ in range(generated_trajectory.shape[0])]), generated_trajectory.T, None)
generated_trajectory = self.bip_instances[self.primary_instance].get_mean_trajectory(num_samples = self.bip_parameters[self.primary_instance]["num_samples"])
self.stat_collector.collect(self.bip_instances[self.primary_instance], np.array([[] for _ in range(generated_trajectory.shape[0])]), generated_trajectory.T, None)

return intprim_framework_ros.srv.InitializeStateResponse(True)

Expand Down