diff --git a/polymetis/polymetis/conf/robot_client/franka_hardware.yaml b/polymetis/polymetis/conf/robot_client/franka_hardware.yaml
index 2d13c1e902..483b588121 100644
--- a/polymetis/polymetis/conf/robot_client/franka_hardware.yaml
+++ b/polymetis/polymetis/conf/robot_client/franka_hardware.yaml
@@ -10,7 +10,7 @@ robot_client:
   default_Kq: [40, 30, 50, 25, 35, 25, 10]
   default_Kqd: [4, 6, 5, 5, 3, 2, 1]
   default_Kx: [750, 750, 750, 15, 15, 15]
-  default_Kxd: [37, 37, 37, 2, 2, 2]
+  default_Kxd: [60, 60, 60, 8, 8, 8]
   hz: ${hz}
   robot_model_cfg: ${robot_model}
   executable_cfg:
diff --git a/polymetis/polymetis/python/polymetis/base_interface.py b/polymetis/polymetis/python/polymetis/base_interface.py
new file mode 100644
index 0000000000..9f594a69d1
--- /dev/null
+++ b/polymetis/polymetis/python/polymetis/base_interface.py
@@ -0,0 +1,269 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import io
+from typing import Dict, Generator, List, Callable
+import time
+import threading
+import atexit
+import logging
+
+import grpc  # This requires `conda install grpcio protobuf`
+import torch
+
+import polymetis
+from polymetis_pb2 import LogInterval, RobotState, ControllerChunk, Empty
+from polymetis_pb2_grpc import PolymetisControllerServerStub
+import torchcontrol as toco
+
+log = logging.getLogger(__name__)
+
+
+# Maximum bytes we send per message to server (so as not to overload it).
+MAX_BYTES_PER_MSG = 1024
+
+# Polling rate when waiting for episode to finish
+POLLING_RATE = 50
+
+# Grpc empty object
+EMPTY = Empty()
+
+
+# Dict container as a nn.module to enable use of jit.save & jit.load
+class ParamDictContainer(torch.nn.Module):
+    """A torch.nn.Module container for a parameter dictionary.
+
+    Note:
+        This is necessary because TorchScript can only script modules,
+        not tensors or dictionaries.
+
+    Args:
+        param_dict: The dictionary mapping parameter names to values.
+    """
+
+    param_dict: Dict[str, torch.Tensor]
+
+    def __init__(self, param_dict: Dict[str, torch.Tensor]):
+        super().__init__()
+        self.param_dict = param_dict
+
+    def forward(self) -> Dict[str, torch.Tensor]:
+        """Simply returns the wrapped parameter dictionary."""
+        return self.param_dict
+
+
+class BaseRobotInterface:
+    """Base robot interface class to initialize a connection to a gRPC controller manager server.
+
+    Args:
+        ip_address: IP address of the gRPC-based controller manager server.
+        port: Port to connect to on the IP address.
+    """
+
+    def __init__(
+        self, ip_address: str = "localhost", port: int = 50051, enforce_version=True
+    ):
+        # Create connection
+        self.channel = grpc.insecure_channel(f"{ip_address}:{port}")
+        self.grpc_connection = PolymetisControllerServerStub(self.channel)
+
+        # Get metadata
+        self.metadata = self.grpc_connection.GetRobotClientMetadata(EMPTY)
+
+        # Check version
+        if enforce_version:
+            client_ver = polymetis.__version__
+            server_ver = self.metadata.polymetis_version
+            assert (
+                client_ver == server_ver
+            ), "Version mismatch between client & server detected! Set enforce_version=False to bypass this error."
+
+    def __del__(self):
+        # Close connection in destructor
+        self.channel.close()
+
+    @staticmethod
+    def _get_msg_generator(scripted_module) -> Generator:
+        """Given a scripted module, return a generator of its serialized bits
+        as byte chunks of max size MAX_BYTES_PER_MSG."""
+        # Write into bytes buffer
+        buffer = io.BytesIO()
+        torch.jit.save(scripted_module, buffer)
+        buffer.seek(0)
+
+        # Create policy generator
+        def msg_generator():
+            # A generator which chunks a scripted module into messages of
+            # size MAX_BYTES_PER_MSG and sends these messages to the server.
+            while True:
+                chunk = buffer.read(MAX_BYTES_PER_MSG)
+                if not chunk:  # end of buffer
+                    break
+                msg = ControllerChunk(torchscript_binary_chunk=chunk)
+                yield msg
+
+        return msg_generator
+
+    def _get_robot_state_log(
+        self, log_interval: LogInterval, timeout: float = None
+    ) -> List[RobotState]:
+        """A private helper method to get the states corresponding to a log_interval from the server.
+
+        Args:
+            log_interval: a message holding start and end indices for a trajectory of RobotStates.
+            timeout: Amount of time (in seconds) to wait before throwing a TimeoutError.
+
+        Returns:
+            If successful, returns a list of RobotState objects.
+
+        """
+        robot_state_generator = self.grpc_connection.GetRobotStateLog(log_interval)
+
+        def cancel_rpc():
+            log.info("Cancelling attempt to get robot state log.")
+            robot_state_generator.cancel()
+            log.info("Cancellation completed.")
+
+        atexit.register(cancel_rpc)
+
+        results = []
+
+        def read_stream():
+            try:
+                for state in robot_state_generator:
+                    results.append(state)
+            except grpc.RpcError as e:
+                log.error(f"Unable to read stream of robot states: {e}")
+
+        read_thread = threading.Thread(target=read_stream)
+        read_thread.start()
+        read_thread.join(timeout=timeout)
+
+        if read_thread.is_alive():
+            raise TimeoutError("Operation timed out.")
+        else:
+            atexit.unregister(cancel_rpc)
+            return results
+
+    def get_robot_state(self) -> RobotState:
+        """Returns the latest RobotState."""
+        return self.grpc_connection.GetRobotState(EMPTY)
+
+    def get_previous_interval(self, timeout: float = None) -> LogInterval:
+        """Get the log indices associated with the most recent (possibly still running) policy."""
+        log_interval = self.grpc_connection.GetEpisodeInterval(EMPTY)
+        assert log_interval.start != -1, "Cannot find previous episode."
+        return log_interval
+
+    def is_running_policy(self) -> bool:
+        log_interval = self.grpc_connection.GetEpisodeInterval(EMPTY)
+        return (
+            log_interval.start != -1  # policy has started
+            and log_interval.end == -1  # policy has not ended
+        )
+
+    def get_previous_log(self, timeout: float = None) -> List[RobotState]:
+        """Get the list of RobotStates associated with the most recent (possibly still running) policy.
+
+        Args:
+            timeout: Amount of time (in seconds) to wait before throwing a TimeoutError.
+
+        Returns:
+            If successful, returns a list of RobotState objects.
+
+        """
+        log_interval = self.get_previous_interval(timeout)
+        return self._get_robot_state_log(log_interval, timeout=timeout)
+
+    def send_torch_policy(
+        self,
+        torch_policy: toco.PolicyModule,
+        blocking: bool = True,
+        timeout: float = None,
+        post_exe_hook: Callable = None,
+    ) -> List[RobotState]:
+        """Sends the ScriptableTorchPolicy to the server.
+
+        Args:
+            torch_policy: An instance of ScriptableTorchPolicy to control the robot.
+            blocking: If True, blocks until the policy is finished executing, then returns the list of RobotStates.
+            timeout: Amount of time (in seconds) to wait before throwing a TimeoutError.
+            post_exe_hook: Optional callable invoked once the policy finishes executing (only used when `blocking` is True).
+
+        Returns:
+            If `blocking`, returns a list of RobotState objects. Otherwise, returns None.
+
+        """
+        start_time = time.time()
+
+        # Script & chunk policy
+        scripted_policy = torch.jit.script(torch_policy)
+        msg_generator = self._get_msg_generator(scripted_policy)
+
+        # Send policy as stream
+        try:
+            log_interval = self.grpc_connection.SetController(msg_generator())
+        except grpc.RpcError as e:
+            raise grpc.RpcError(f"POLYMETIS SERVER ERROR --\n{e.details()}") from None
+
+        if blocking:
+            # Check policy termination
+            while log_interval.end == -1:
+                log_interval = self.grpc_connection.GetEpisodeInterval(EMPTY)
+
+                if timeout is not None and time.time() - start_time > timeout:
+                    raise TimeoutError("Operation timed out.")
+                time.sleep(1.0 / POLLING_RATE)
+
+            # Execute post-execution hook
+            if post_exe_hook is not None:
+                post_exe_hook()
+
+            # Retrieve robot state log
+            if timeout is not None:
+                time_passed = time.time() - start_time
+                timeout = timeout - time_passed
+            return self._get_robot_state_log(log_interval, timeout=timeout)
+
+    def update_current_policy(self, param_dict: Dict[str, torch.Tensor]) -> int:
+        """Updates the current policy with a (possibly incomplete) dictionary holding the updated values.
+
+        Args:
+            param_dict: A dictionary mapping from param_name to updated torch.Tensor values.
+
+        Returns:
+            Index offset from the beginning of the episode when the update was applied.
+
+        """
+        # Script & chunk params
+        scripted_params = torch.jit.script(ParamDictContainer(param_dict))
+        msg_generator = self._get_msg_generator(scripted_params)
+
+        # Send params container as stream
+        try:
+            update_interval = self.grpc_connection.UpdateController(msg_generator())
+        except grpc.RpcError as e:
+            raise grpc.RpcError(f"POLYMETIS SERVER ERROR --\n{e.details()}") from None
+        episode_interval = self.grpc_connection.GetEpisodeInterval(EMPTY)
+
+        return update_interval.start - episode_interval.start
+
+    def terminate_current_policy(
+        self, return_log: bool = True, timeout: float = None
+    ) -> List[RobotState]:
+        """Terminates the currently running policy and (optionally) returns its trajectory.
+
+        Args:
+            return_log: whether or not to block & return the policy's trajectory.
+            timeout: Amount of time (in seconds) to wait before throwing a TimeoutError.
+
+        Returns:
+            If `return_log`, returns the list of RobotStates corresponding to the current policy's execution.
+
+        """
+        # Send termination
+        log_interval = self.grpc_connection.TerminateController(EMPTY)
+
+        # Query episode log
+        if return_log:
+            return self._get_robot_state_log(log_interval, timeout=timeout)
diff --git a/polymetis/polymetis/python/polymetis/robot_interface.py b/polymetis/polymetis/python/polymetis/robot_interface.py
index b5510b8510..afad2238ac 100644
--- a/polymetis/polymetis/python/polymetis/robot_interface.py
+++ b/polymetis/polymetis/python/polymetis/robot_interface.py
@@ -2,18 +2,18 @@
 
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-import io
-from typing import Dict, Generator, List, Tuple
+from functools import partial
+from typing import List, Tuple, Optional
 import time
 import tempfile
-import threading
-import atexit
 import logging
+from dataclasses import dataclass
 
 import grpc  # This requires `conda install grpcio protobuf`
 import torch
 
 import polymetis
+from polymetis.base_interface import BaseRobotInterface
 from polymetis_pb2 import LogInterval, RobotState, ControllerChunk, Empty
 from polymetis_pb2_grpc import PolymetisControllerServerStub
@@ -24,248 +24,13 @@
 log = logging.getLogger(__name__)
 
 
-# Maximum bytes we send per message to server (so as not to overload it).
-MAX_BYTES_PER_MSG = 1024
-
-# Polling rate when waiting for episode to finish
-POLLING_RATE = 50
-
-# Grpc empty object
-EMPTY = Empty()
-
-
-# Dict container as a nn.module to enable use of jit.save & jit.load
-class ParamDictContainer(torch.nn.Module):
-    """A torch.nn.Module container for a parameter key.
-
-    Note:
-        This is necessary because TorchScript can only script modules,
-        not tensors or dictionaries.
-
-    Args:
-        param_dict: The dictionary mapping parameter names to values.
-    """
-
-    param_dict: Dict[str, torch.Tensor]
-
-    def __init__(self, param_dict: Dict[str, torch.Tensor]):
-        super().__init__()
-        self.param_dict = param_dict
-
-    def forward(self) -> Dict[str, torch.Tensor]:
-        """Simply returns the wrapped parameter dictionary."""
-        return self.param_dict
-
-
-class BaseRobotInterface:
-    """Base robot interface class to initialize a connection to a gRPC controller manager server.
-
-    Args:
-        ip_address: IP address of the gRPC-based controller manager server.
-        port: Port to connect to on the IP address.
-    """
-
-    def __init__(
-        self, ip_address: str = "localhost", port: int = 50051, enforce_version=True
-    ):
-        # Create connection
-        self.channel = grpc.insecure_channel(f"{ip_address}:{port}")
-        self.grpc_connection = PolymetisControllerServerStub(self.channel)
-
-        # Get metadata
-        self.metadata = self.grpc_connection.GetRobotClientMetadata(EMPTY)
-
-        # Check version
-        if enforce_version:
-            client_ver = polymetis.__version__
-            server_ver = self.metadata.polymetis_version
-            assert (
-                client_ver == server_ver
-            ), "Version mismatch between client & server detected! Set enforce_version=False to bypass this error."
-
-    def __del__(self):
-        # Close connection in destructor
-        self.channel.close()
-
-    @staticmethod
-    def _get_msg_generator(scripted_module) -> Generator:
-        """Given a scripted module, return a generator of its serialized bits
-        as byte chunks of max size MAX_BYTES_PER_MSG."""
-        # Write into bytes buffer
-        buffer = io.BytesIO()
-        torch.jit.save(scripted_module, buffer)
-        buffer.seek(0)
-
-        # Create policy generator
-        def msg_generator():
-            # A generator which chunks a scripted module into messages of
-            # size MAX_BYTES_PER_MSG and send these messages to the server.
-            while True:
-                chunk = buffer.read(MAX_BYTES_PER_MSG)
-                if not chunk:  # end of buffer
-                    break
-                msg = ControllerChunk(torchscript_binary_chunk=chunk)
-                yield msg
-
-        return msg_generator
-
-    def _get_robot_state_log(
-        self, log_interval: LogInterval, timeout: float = None
-    ) -> List[RobotState]:
-        """A private helper method to get the states corresponding to a log_interval from the server.
-
-        Args:
-            log_interval: a message holding start and end indices for a trajectory of RobotStates.
-            timeout: Amount of time (in seconds) to wait before throwing a TimeoutError.
-
-        Returns:
-            If successful, returns a list of RobotState objects.
- - """ - robot_state_generator = self.grpc_connection.GetRobotStateLog(log_interval) - - def cancel_rpc(): - log.info("Cancelling attempt to get robot state log.") - robot_state_generator.cancel() - log.info(f"Cancellation completed.") - - atexit.register(cancel_rpc) - - results = [] - - def read_stream(): - try: - for state in robot_state_generator: - results.append(state) - except grpc.RpcError as e: - log.error(f"Unable to read stream of robot states: {e}") - - read_thread = threading.Thread(target=read_stream) - read_thread.start() - read_thread.join(timeout=timeout) - - if read_thread.is_alive(): - raise TimeoutError("Operation timed out.") - else: - atexit.unregister(cancel_rpc) - return results - - def get_robot_state(self) -> RobotState: - """Returns the latest RobotState.""" - return self.grpc_connection.GetRobotState(EMPTY) - - def get_previous_interval(self, timeout: float = None) -> LogInterval: - """Get the log indices associated with the currently running policy.""" - log_interval = self.grpc_connection.GetEpisodeInterval(EMPTY) - assert log_interval.start != -1, "Cannot find previous episode." - return log_interval - - def is_running_policy(self) -> bool: - log_interval = self.grpc_connection.GetEpisodeInterval(EMPTY) - return ( - log_interval.start != -1 # policy has started - and log_interval.end == -1 # policy has not ended - ) - - def get_previous_log(self, timeout: float = None) -> List[RobotState]: - """Get the list of RobotStates associated with the currently running policy. - - Args: - timeout: Amount of time (in seconds) to wait before throwing a TimeoutError. - - Returns: - If successful, returns a list of RobotState objects. - - """ - log_interval = self.get_previous_interval(timeout) - return self._get_robot_state_log(log_interval, timeout=timeout) - - def send_torch_policy( - self, - torch_policy: toco.PolicyModule, - blocking: bool = True, - timeout: float = None, - ) -> List[RobotState]: - """Sends the ScriptableTorchPolicy to the server. - - Args: - torch_policy: An instance of ScriptableTorchPolicy to control the robot. - blocking: If True, blocks until the policy is finished executing, then returns the list of RobotStates. - timeout: Amount of time (in seconds) to wait before throwing a TimeoutError. - - Returns: - If `blocking`, returns a list of RobotState objects. Otherwise, returns None. - - """ - start_time = time.time() - - # Script & chunk policy - scripted_policy = torch.jit.script(torch_policy) - msg_generator = self._get_msg_generator(scripted_policy) - - # Send policy as stream - try: - log_interval = self.grpc_connection.SetController(msg_generator()) - except grpc.RpcError as e: - raise grpc.RpcError(f"POLYMETIS SERVER ERROR --\n{e.details()}") from None - - if blocking: - # Check policy termination - while log_interval.end == -1: - log_interval = self.grpc_connection.GetEpisodeInterval(EMPTY) - - if timeout is not None and time.time() - start_time > timeout: - raise TimeoutError("Operation timed out.") - time.sleep(1.0 / POLLING_RATE) - - # Retrieve robot state log - if timeout is not None: - time_passed = time.time() - start_time - timeout = timeout - time_passed - return self._get_robot_state_log(log_interval, timeout=timeout) - - def update_current_policy(self, param_dict: Dict[str, torch.Tensor]) -> int: - """Updates the current policy's with a (possibly incomplete) dictionary holding the updated values. - - Args: - param_dict: A dictionary mapping from param_name to updated torch.Tensor values. 
-
-        Returns:
-            Index offset from the beginning of the episode when the update was applied.
-
-        """
-        # Script & chunk params
-        scripted_params = torch.jit.script(ParamDictContainer(param_dict))
-        msg_generator = self._get_msg_generator(scripted_params)
-
-        # Send params container as stream
-        try:
-            update_interval = self.grpc_connection.UpdateController(msg_generator())
-        except grpc.RpcError as e:
-            raise grpc.RpcError(f"POLYMETIS SERVER ERROR --\n{e.details()}") from None
-        episode_interval = self.grpc_connection.GetEpisodeInterval(EMPTY)
-
-        return update_interval.start - episode_interval.start
-
-    def terminate_current_policy(
-        self, return_log: bool = True, timeout: float = None
-    ) -> List[RobotState]:
-        """Terminates the currently running policy and (optionally) return its trajectory.
-
-        Args:
-            return_log: whether or not to block & return the policy's trajectory.
-            timeout: Amount of time (in seconds) to wait before throwing a TimeoutError.
-
-        Returns:
-            If `return_log`, returns the list of RobotStates the list of RobotStates corresponding to the current policy's execution.
-
-        """
-        # Send termination
-        log_interval = self.grpc_connection.TerminateController(EMPTY)
-
-        # Query episode log
-        if return_log:
-            return self._get_robot_state_log(log_interval, timeout=timeout)
+@dataclass
+class DefaultControllerConfig:
+    q_des: torch.Tensor
+    Kq: torch.Tensor
+    Kqd: torch.Tensor
+    Kx: torch.Tensor
+    Kxd: torch.Tensor
 
 
 class RobotInterface(BaseRobotInterface):
@@ -307,6 +72,16 @@ def __init__(
 
         self.use_grav_comp = use_grav_comp
 
+        # Initialize reference states
+        self._def_controller_cfg = DefaultControllerConfig(
+            q_des=self.home_pose,
+            Kq=self.Kq_default,
+            Kqd=self.Kqd_default,
+            Kx=self.Kx_default,
+            Kxd=self.Kxd_default,
+        )
+        self._reset_default_controller()
+
     def _adaptive_time_to_go(self, joint_displacement: torch.Tensor):
         """Compute adaptive time_to_go
         Computes the corresponding time_to_go such that the mean velocity is equal to one-eighth
@@ -322,6 +97,48 @@ def _adaptive_time_to_go(self, joint_displacement: torch.Tensor):
         time_to_go = torch.max(joint_pos_diff / joint_vel_limits * 8.0)
         return max(time_to_go, self.time_to_go_default)
 
+    def _reset_default_controller(self):
+        if not self.is_running_policy():
+            self._def_controller_cfg.q_des = self.get_joint_positions()
+            self._set_default_controller()
+
+    def _set_default_controller(
+        self,
+        joint_pos_desired: Optional[torch.Tensor] = None,
+        Kq: Optional[torch.Tensor] = None,
+        Kqd: Optional[torch.Tensor] = None,
+        Kx: Optional[torch.Tensor] = None,
+        Kxd: Optional[torch.Tensor] = None,
+    ):
+        # Update default controller config
+        if joint_pos_desired is not None:
+            self._def_controller_cfg.q_des = joint_pos_desired
+
+        gain_updates = {"Kq": Kq, "Kqd": Kqd, "Kx": Kx, "Kxd": Kxd}
+        for key, K in gain_updates.items():
+            if K is not None:
+                assert (
+                    type(K) is torch.Tensor
+                ), f"Invalid gain type. Has to be torch.Tensor, got {type(K)} instead."
+                K_old = getattr(self._def_controller_cfg, key)
+                assert (
+                    K.shape == K_old.shape
+                ), f"Invalid gain shape. Got {K.shape} instead of {K_old.shape}"
+
+                setattr(self._def_controller_cfg, key, K)
+
+        # Send updated controller
+        default_controller = toco.policies.HybridJointImpedanceControl(
+            joint_pos_current=self._def_controller_cfg.q_des,
+            Kq=self._def_controller_cfg.Kq,
+            Kqd=self._def_controller_cfg.Kqd,
+            Kx=self._def_controller_cfg.Kx,
+            Kxd=self._def_controller_cfg.Kxd,
+            robot_model=self.robot_model,
+            ignore_gravity=self.use_grav_comp,
+        )
+        self.send_torch_policy(default_controller, blocking=False)
+
     def solve_inverse_kinematics(
         self,
         position: torch.Tensor,
@@ -359,6 +176,11 @@ def set_robot_model(self, robot_description_path: str, ee_link_name: str = None)
             robot_description_path, ee_link_name
         )
 
+    def set_control_gains(self, **kwargs):
+        """Update tracking controller gains."""
+        self._set_default_controller(joint_pos_desired=None, **kwargs)
+        log.warning("Controller gains updated.")
+
     """
     Getter methods
     """
@@ -396,17 +218,12 @@ def move_to_joint_positions(
         positions: torch.Tensor,
         time_to_go: float = None,
         delta: bool = False,
-        Kq: torch.Tensor = None,
-        Kqd: torch.Tensor = None,
-        **kwargs,
     ) -> List[RobotState]:
         """Uses JointGoToPolicy to move to the desired positions with the given gains.
         Args:
            positions: Desired target joint positions.
            time_to_go: Amount of time to execute the motion. Uses an adaptive value if not specified (see `_adaptive_time_to_go` for details).
            delta: Whether the specified `positions` are relative to current pose or absolute.
-            Kq: Joint P gains for the tracking controller. Uses default values if not specified.
-            Kqd: Joint D gains for the tracking controller. Uses default values if not specified.
 
        Returns:
            Same as `send_torch_policy`
@@ -438,28 +255,38 @@ def move_to_joint_positions(
             time_to_go=time_to_go,
             hz=self.hz,
         )
+        joint_pos_desired_final = waypoints[-1]["position"]
 
         # Create & execute policy
         torch_policy = toco.policies.JointTrajectoryExecutor(
             joint_pos_trajectory=[waypoint["position"] for waypoint in waypoints],
             joint_vel_trajectory=[waypoint["velocity"] for waypoint in waypoints],
-            Kq=self.Kq_default if Kq is None else Kq,
-            Kqd=self.Kqd_default if Kqd is None else Kqd,
-            Kx=self.Kx_default,
-            Kxd=self.Kxd_default,
+            Kq=self._def_controller_cfg.Kq,
+            Kqd=self._def_controller_cfg.Kqd,
+            Kx=self._def_controller_cfg.Kx,
+            Kxd=self._def_controller_cfg.Kxd,
             robot_model=self.robot_model,
             ignore_gravity=self.use_grav_comp,
         )
 
-        return self.send_torch_policy(torch_policy=torch_policy, **kwargs)
+        self._reset_default_controller()
+        return self.send_torch_policy(
+            torch_policy=torch_policy,
+            post_exe_hook=partial(
+                self._set_default_controller, joint_pos_desired_final
+            ),
+        )
 
-    def go_home(self, *args, **kwargs) -> List[RobotState]:
-        """Calls move_to_joint_positions to the current home positions."""
+    def go_home(self, time_to_go: float = None) -> List[RobotState]:
+        """Calls move_to_joint_positions to the current home positions.
+        Args:
+            time_to_go: Amount of time to execute the motion. Uses an adaptive value if not specified (see `_adaptive_time_to_go` for details).
+        """
         assert (
             self.home_pose is not None
         ), "Home pose not assigned! Call 'set_home_pose()' to enable homing"
         return self.move_to_joint_positions(
-            positions=self.home_pose, delta=False, *args, **kwargs
+            positions=self.home_pose, time_to_go=time_to_go, delta=False
         )
 
     def move_to_ee_pose(
@@ -468,10 +295,7 @@
         orientation: torch.Tensor = None,
         time_to_go: float = None,
         delta: bool = False,
-        Kx: torch.Tensor = None,
-        Kxd: torch.Tensor = None,
         op_space_interp: bool = True,
-        **kwargs,
     ) -> List[RobotState]:
         """Uses an operational space controller to move to a desired end-effector position (and, optionally orientation).
         Args:
@@ -541,20 +365,27 @@ def move_to_ee_pose(
                 robot_model=self.robot_model,
                 home_pose=self.home_pose,
             )
+            joint_pos_desired_final = waypoints[-1]["position"]
 
             # Create joint tracking policy and run
             torch_policy = toco.policies.JointTrajectoryExecutor(
                 joint_pos_trajectory=[waypoint["position"] for waypoint in waypoints],
                 joint_vel_trajectory=[waypoint["velocity"] for waypoint in waypoints],
-                Kq=self.Kq_default,
-                Kqd=self.Kqd_default,
-                Kx=self.Kx_default if Kx is None else Kx,
-                Kxd=self.Kxd_default if Kxd is None else Kxd,
+                Kq=self._def_controller_cfg.Kq,
+                Kqd=self._def_controller_cfg.Kqd,
+                Kx=self._def_controller_cfg.Kx,
+                Kxd=self._def_controller_cfg.Kxd,
                 robot_model=self.robot_model,
                 ignore_gravity=self.use_grav_comp,
             )
 
-            return self.send_torch_policy(torch_policy=torch_policy, **kwargs)
+            self._reset_default_controller()
+            return self.send_torch_policy(
+                torch_policy=torch_policy,
+                post_exe_hook=partial(
+                    self._set_default_controller, joint_pos_desired_final
+                ),
+            )
 
         else:
             # Use joint space controller to move to joint target
@@ -570,57 +401,37 @@ def start_joint_impedance(self, Kq=None, Kqd=None, adaptive=True, **kwargs):
         """Starts joint position control mode.
         Runs an non-blocking joint impedance controller.
         The desired joint positions can be updated using `update_desired_joint_positions`
+        **This method is being deprecated.**
         """
-        if adaptive:
-            torch_policy = toco.policies.HybridJointImpedanceControl(
-                joint_pos_current=self.get_joint_positions(),
-                Kq=self.Kq_default if Kq is None else Kq,
-                Kqd=self.Kqd_default if Kqd is None else Kqd,
-                Kx=self.Kx_default,
-                Kxd=self.Kxd_default,
-                robot_model=self.robot_model,
-                ignore_gravity=self.use_grav_comp,
-            )
-        else:
-            torch_policy = toco.policies.JointImpedanceControl(
-                joint_pos_current=self.get_joint_positions(),
-                Kp=self.Kq_default if Kq is None else Kq,
-                Kd=self.Kqd_default if Kqd is None else Kqd,
-                robot_model=self.robot_model,
-                ignore_gravity=self.use_grav_comp,
-            )
-
-        return self.send_torch_policy(torch_policy=torch_policy, blocking=False)
+        log.warning(
+            "'start_joint_impedance' is being deprecated as desired states can now be updated directly without sending a policy. Nothing is being sent to the robot."
+        )
 
     def start_cartesian_impedance(self, Kx=None, Kxd=None, **kwargs):
         """Starts Cartesian position control mode.
         Runs an non-blocking Cartesian impedance controller.
         The desired EE pose can be updated using `update_desired_ee_pose`
+        **This method is being deprecated.**
         """
-        torch_policy = toco.policies.HybridJointImpedanceControl(
-            joint_pos_current=self.get_joint_positions(),
-            Kq=self.Kq_default,
-            Kqd=self.Kqd_default,
-            Kx=self.Kx_default if Kx is None else Kx,
-            Kxd=self.Kxd_default if Kxd is None else Kxd,
-            robot_model=self.robot_model,
-            ignore_gravity=self.use_grav_comp,
+        log.warning(
+            "'start_cartesian_impedance' is being deprecated as desired states can now be updated directly without sending a policy. Nothing is being sent to the robot."
        )
 
-        return self.send_torch_policy(torch_policy=torch_policy, blocking=False)
-
     def update_desired_joint_positions(self, positions: torch.Tensor) -> int:
         """Update the desired joint positions used by the joint position control mode.
         Requires starting a joint impedance controller with `start_joint_impedance` beforehand.
         """
+        self._reset_default_controller()
         try:
             update_idx = self.update_current_policy({"joint_pos_desired": positions})
         except grpc.RpcError as e:
             log.error(
-                "Unable to update desired joint positions. Use 'start_joint_impedance' to start a joint impedance controller."
+                "Unable to update desired robot state. Is another policy currently running?"
             )
             raise e
+
+        self._def_controller_cfg.q_des = positions.clone()
+
         return update_idx
 
     def update_desired_ee_pose(
@@ -687,17 +498,17 @@ def get_joint_angles(self) -> torch.Tensor:
         """Functionally identical to `get_joint_positions`.
         **This method is being deprecated in favor of `get_joint_positions`.**
         """
-        log.warning(
+        log.error(
             "The method 'get_joint_angles' is deprecated, use 'get_joint_positions' instead."
         )
-        return self.get_joint_positions()
+        raise NotImplementedError
 
     def pose_ee(self) -> Tuple[torch.Tensor, torch.Tensor]:
         """Functionally identical to `get_ee_pose`.
         **This method is being deprecated in favor of `get_ee_pose`.**
         """
-        log.warning("The method 'pose_ee' is deprecated, use 'get_ee_pose' instead.")
-        return self.get_ee_pose()
+        log.error("The method 'pose_ee' is deprecated, use 'get_ee_pose' instead.")
+        raise NotImplementedError
 
     def set_joint_positions(
         self, desired_positions, *args, **kwargs
@@ -705,12 +516,10 @@ def set_joint_positions(
         """Functionally identical to `move_to_joint_positions`.
         **This method is being deprecated in favor of `move_to_joint_positions`.**
         """
-        log.warning(
+        log.error(
             "The method 'set_joint_positions' is deprecated, use 'move_to_joint_positions' instead."
         )
-        return self.move_to_joint_positions(
-            positions=desired_positions, *args, **kwargs
-        )
+        raise NotImplementedError
 
     def move_joint_positions(
         self, delta_positions, *args, **kwargs
@@ -718,21 +527,19 @@ def move_joint_positions(
         """Functionally identical to calling `move_to_joint_positions` with the argument `delta=True`.
         **This method is being deprecated in favor of `move_to_joint_positions`.**
         """
-        log.warning(
+        log.error(
             "The method 'set_joint_positions' is deprecated, use 'move_to_joint_positions' with 'delta=True' instead."
         )
-        return self.move_to_joint_positions(
-            positions=delta_positions, delta=True, *args, **kwargs
-        )
+        raise NotImplementedError
 
     def set_ee_pose(self, *args, **kwargs) -> List[RobotState]:
         """Functionally identical to `move_to_ee_pose`.
         **This method is being deprecated in favor of `move_to_ee_pose`.**
         """
-        log.warning(
+        log.error(
             "The method 'set_ee_pose' is deprecated, use 'move_to_ee_pose' instead."
         )
-        return self.move_to_ee_pose(*args, **kwargs)
+        raise NotImplementedError
 
     def move_ee_xyz(
         self, displacement: torch.Tensor, use_orient: bool = True, **kwargs
@@ -740,7 +547,7 @@ def move_ee_xyz(
         """Functionally identical to calling `move_to_ee_pose` with the argument `delta=True`.
         **This method is being deprecated in favor of `move_to_ee_pose`.**
         """
-        log.warning(
+        log.error(
             "The method 'move_ee_xyz' is deprecated, use 'move_to_ee_pose' with 'delta=True' instead."
         )
-        return self.move_to_ee_pose(position=displacement, delta=True, **kwargs)
+        raise NotImplementedError
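
Usage note (illustrative, not part of the patch): after this change a persistent HybridJointImpedanceControl default controller runs between motion commands, gains are configured once via set_control_gains instead of per-call Kq/Kqd/Kx/Kxd arguments, and start_joint_impedance / start_cartesian_impedance become deprecated no-ops. A minimal client sketch under the assumption that a Polymetis server is running locally; the gain and displacement values below are made up:

    import torch
    from polymetis import RobotInterface

    robot = RobotInterface(ip_address="localhost")  # connect to a running server

    # Gains are now set once on the interface; shapes must match the defaults.
    robot.set_control_gains(Kq=torch.Tensor([40, 30, 50, 25, 35, 25, 10]))

    # Blocking motions; afterwards the default controller holds the final waypoint.
    robot.go_home()
    robot.move_to_joint_positions(robot.get_joint_positions() + 0.05, time_to_go=2.0)

    # Desired states can be streamed directly; no start_joint_impedance needed.
    q_des = robot.get_joint_positions()
    for _ in range(100):
        q_des[6] += 0.001  # hypothetical small motion on the last joint
        robot.update_desired_joint_positions(q_des)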