
Fail to backprop through LQRStepFn #47

@anby-dmr

Description

When I tried to backpropagate a loss through LQRStepFn, the following error occurred:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[8], line 10
---> 10 loss.backward()

File d:\Anaconda3\envs\graduation\lib\site-packages\torch\_tensor.py:521, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    511 if has_torch_function_unary(self):
    512     return handle_torch_function(
    513         Tensor.backward,
    514         (self,),
   (...)
    519         inputs=inputs,
    520     )
--> 521 torch.autograd.backward(
    522     self, gradient, retain_graph, create_graph, inputs=inputs
    523 )

File d:\Anaconda3\envs\graduation\lib\site-packages\torch\autograd\__init__.py:289, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    284     retain_graph = create_graph
    286 # The reason we repeat the same comment below is that
    287 # some Python versions print out the first line of a multi-line function
    288 # calls in the traceback and some print out the last line
--> 289 _engine_run_backward(
    290     tensors,
    291     grad_tensors_,
    292     retain_graph,
    293     create_graph,
    294     inputs,
    295     allow_unreachable=True,
    296     accumulate_grad=True,
    297 )

File d:\Anaconda3\envs\graduation\lib\site-packages\torch\autograd\graph.py:769, in _engine_run_backward(t_outputs, *args, **kwargs)
    767     unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
    768 try:
--> 769     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    770         t_outputs, *args, **kwargs
    771     )  # Calls into the C++ engine to run the backward pass
    772 finally:
    773     if attach_logging_hooks:

File d:\Anaconda3\envs\graduation\lib\site-packages\torch\autograd\function.py:306, in BackwardCFunction.apply(self, *args)
    300     raise RuntimeError(
    301         "Implementing both 'backward' and 'vjp' for a custom "
    302         "Function is not allowed. You should only implement one "
    303         "of them."
    304     )
    305 user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_fn
--> 306 return user_fn(self, *args)

TypeError: backward() takes from 3 to 5 positional arguments but 7 were given
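For reference, the same error can be reproduced outside the repo with a toy autograd.Function that mimics the output arity (hypothetical names, for illustration only, not the real LQRStepFn):

    import torch
    from torch.autograd import Function

    class SixOutputFn(Function):
        @staticmethod
        def forward(ctx, x):
            # Six outputs, mirroring the arity of LQRStepFn.forward.
            return x * 1, x * 2, x * 3, x * 4, x * 5, x * 6

        @staticmethod
        def backward(ctx, dl_dx, dl_du, temp=None, temp2=None):
            # Only four gradient slots besides ctx; autograd passes six.
            return dl_dx

    x = torch.ones(3, requires_grad=True)
    loss = SixOutputFn.apply(x)[0].sum()
    loss.backward()
    # -> TypeError: backward() takes from 3 to 5 positional arguments but 7 were given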

I suspect the problem is in lqr_step.py, class LQRStepFn, where the parameter counts of forward() and backward() do not match:

    class LQRStepFn(Function):
        # @profile
        @staticmethod
        def forward(ctx, x_init, C, c, F, f=None):
            ...  # forward pass elided
            return new_x, new_u, torch.Tensor([n_total_qp_iter]), \
              for_out.costs, for_out.full_du_norm, for_out.mean_alphas

        @staticmethod
        def backward(ctx, dl_dx, dl_du, temp=None, temp2=None):
            start = time.time()
            x_init, C, c, F, f, new_x, new_u = ctx.saved_tensors

In the snippet above, forward() returns six outputs, but backward() accepts at most four gradient arguments besides ctx. The autograd engine always passes one gradient per forward output, so backward() is invoked with seven positional arguments (ctx plus six gradients), which is exactly what the TypeError reports. After adding two more arguments to backward(), the loss backpropagates without error:

        @staticmethod
        def backward(ctx, dl_dx, dl_du, temp=None, temp2=None, temp3=None, temp4=None):
            start = time.time()
            x_init, C, c, F, f, new_x, new_u = ctx.saved_tensors
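The two new arguments receive the gradients for the last two forward outputs (for_out.full_du_norm and for_out.mean_alphas), which backward() ignores, just as temp and temp2 are already ignored. The fixed signature can be sanity-checked in isolation with the same toy Function (again hypothetical names): with one gradient slot per forward output, backward() runs and returns one gradient per forward input.

    import torch
    from torch.autograd import Function

    class SixOutputFn(Function):
        @staticmethod
        def forward(ctx, x):
            return x * 1, x * 2, x * 3, x * 4, x * 5, x * 6

        @staticmethod
        def backward(ctx, g1, g2, g3=None, g4=None, g5=None, g6=None):
            # One slot per forward output; unused output grads are
            # materialized as zeros by default, so summing is safe.
            return g1 * 1 + g2 * 2 + g3 * 3 + g4 * 4 + g5 * 5 + g6 * 6

    x = torch.ones(3, requires_grad=True)
    SixOutputFn.apply(x)[0].sum().backward()
    print(x.grad)  # tensor([1., 1., 1.])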
