A PyTorch implementation of InfoGAIL built on top of stable-baselines3 and imitation.
Core changes were made to v0.2.0 of the imitation repository to implement InfoGAIL. Only the necessary files from the imitation repository have been kept.
Two new classes in src\imitation\rewards\discrim_nets.py:
WassersteinDiscrimNet: inherits DiscrimNet and overrides disc_loss, which implements the Wasserstein loss used to train the discriminator.
DiscrimNetWGAIL: inherits WassersteinDiscrimNet and overrides reward_train, using -logits as the reward for the generator.
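A minimal sketch of the two overridden pieces, written in plain PyTorch rather than against the imitation API. It assumes the discriminator's logits are meant to be high for generator samples and low for expert samples, which is the convention under which -logits is a sensible generator reward. The function names are hypothetical, and the repository's disc_loss may differ in detail.

```python
import torch


def wasserstein_disc_loss(logits: torch.Tensor, labels_gen_is_one: torch.Tensor) -> torch.Tensor:
    """Critic loss: mean score on expert samples minus mean score on generator samples.

    Hypothetical helper; assumes scores should be high for generator samples.
    """
    is_gen = labels_gen_is_one.bool()
    # Minimising this raises generator scores and lowers expert scores.
    return logits[~is_gen].mean() - logits[is_gen].mean()


def wgail_reward(logits: torch.Tensor) -> torch.Tensor:
    """Generator reward: -logits, so lower (more expert-like) scores earn more reward."""
    return -logits


logits = torch.tensor([0.5, 1.5, -1.0, -2.0])
labels = torch.tensor([1, 1, 0, 0])  # 1 = generator sample, 0 = expert sample
loss = wasserstein_disc_loss(logits, labels)  # (-1.5) - 1.0 = -2.5
rewards = wgail_reward(logits)                # tensor([-0.5, -1.5, 1.0, 2.0])
```

Minimising the critic loss pushes generator scores up and expert scores down, while the policy, rewarded with -logits, is pushed toward low (expert-like) scores.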
Two new classes in src\imitation\algorithms\adversarial.py:
WGAIL: differs from the GAIL class by using DiscrimNetWGAIL as the discriminator and setting disc_opt_cls to RMSprop instead of Adam.
WassersteinAdversarialTrainer: inherits the AdversarialTrainer class and adds gradient clipping to the train_disc function.
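The trainer-side changes can be sketched as a single critic update. This is an illustration under assumptions, not the repository's train_disc: the toy critic, the learning rate, and the max_grad_norm threshold are hypothetical, and the clipping shown is gradient-norm clipping via torch.nn.utils.clip_grad_norm_.

```python
import torch
import torch.nn as nn


def train_critic_step(critic, optimizer, expert_batch, gen_batch, max_grad_norm=1.0):
    """One critic update with gradient-norm clipping (max_grad_norm is hypothetical)."""
    optimizer.zero_grad()
    # Wasserstein critic loss: scores high for generator, low for expert samples.
    loss = critic(expert_batch).mean() - critic(gen_batch).mean()
    loss.backward()
    # Rescale gradients so their global L2 norm is at most max_grad_norm.
    torch.nn.utils.clip_grad_norm_(critic.parameters(), max_grad_norm)
    optimizer.step()
    return loss.item()


torch.manual_seed(0)
critic = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 1))
# As in WGAIL, RMSprop replaces Adam as the discriminator optimizer (lr is an assumption).
optimizer = torch.optim.RMSprop(critic.parameters(), lr=5e-5)
expert = torch.randn(64, 4) + 1.0
gen = torch.randn(64, 4) - 1.0
losses = [train_critic_step(critic, optimizer, expert, gen) for _ in range(200)]
```

Note that the original WGAN recipe clamps the critic's weights after each step (for example p.data.clamp_(-c, c) for every parameter p) rather than clipping gradients; which of the two the repository's train_disc applies is only described here as "gradient clipping".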
Sample test script for WGAIL: python .\minigrid_wgail_training_script.py -r testing_wgail -t minigrid_empty_right_down -f --vis-trained
During testing, the trained policy remained consistent when the environment was changed from "MiniGrid-Empty-6x6-v0" to "MiniGrid-Empty-8x8-v0" and "MiniGrid-Empty-5x5-v0".
To avoid further core changes to the imitation library, all classes needed to run the CNN versions of GAIL and WGAIL are kept in the cnn_modules folder.
Two new discriminator classes in cnn_modules/cnn_discriminator.py:
ActObsCNN: uses a NatureCNN backbone from stable-baselines3 to extract features from an image observation. The observation features are concatenated with the action, and the rest works as it would in ActObsMLP.
ObsOnlyCNN: same as ActObsCNN, but the action is not used.
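A rough sketch of the ActObsCNN/ObsOnlyCNN idea in plain PyTorch. To stay self-contained it re-creates the layer layout of stable-baselines3's NatureCNN (three convolutions, then a linear layer) instead of importing it. The class names, the 3x64x64 observation shape, the one-hot action encoding, and the MLP head size are all assumptions, not details taken from cnn_modules/cnn_discriminator.py. Observations are assumed to be channels-first and already scaled to [0, 1].

```python
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class NatureCNNLike(nn.Module):
    """Same layer layout as stable-baselines3's NatureCNN (re-created, not imported)."""

    def __init__(self, obs_shape: Tuple[int, int, int], features_dim: int = 512):
        super().__init__()
        channels, height, width = obs_shape
        self.cnn = nn.Sequential(
            nn.Conv2d(channels, 32, kernel_size=8, stride=4), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(),
            nn.Flatten(),
        )
        # Infer the flattened feature size with a dummy forward pass.
        with torch.no_grad():
            n_flatten = self.cnn(torch.zeros(1, channels, height, width)).shape[1]
        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return self.linear(self.cnn(obs))


class ActObsCNNSketch(nn.Module):
    """Hypothetical discriminator: CNN features (+ one-hot action) -> MLP -> one logit."""

    def __init__(self, obs_shape, n_actions: int, use_action: bool = True,
                 features_dim: int = 512, hid_size: int = 64):
        super().__init__()
        self.n_actions = n_actions
        self.use_action = use_action
        self.backbone = NatureCNNLike(obs_shape, features_dim)
        in_dim = features_dim + (n_actions if use_action else 0)
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, hid_size), nn.ReLU(), nn.Linear(hid_size, 1)
        )

    def forward(self, obs: torch.Tensor, act: Optional[torch.Tensor] = None) -> torch.Tensor:
        feats = self.backbone(obs)
        if self.use_action:
            # Concatenate image features with the one-hot action, then score with the MLP.
            act_onehot = F.one_hot(act, self.n_actions).float()
            feats = torch.cat([feats, act_onehot], dim=1)
        return self.mlp(feats).squeeze(-1)


obs = torch.rand(5, 3, 64, 64)
acts = torch.randint(0, 7, (5,))  # MiniGrid has 7 discrete actions
act_obs_logits = ActObsCNNSketch((3, 64, 64), n_actions=7)(obs, acts)
obs_only_logits = ActObsCNNSketch((3, 64, 64), n_actions=7, use_action=False)(obs)
```

Passing use_action=False gives the ObsOnlyCNN variant, whose score depends on the observation alone.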
To use the CNN version of GAIL or WGAIL, omit the -f argument.
A sample test script for CNN-GAIL: python .\minigrid_gail_training_script.py -r testing_cnngail -t img_no_stack_minigrid_empty_down_right --vis-trained