Skip to content

Refactor neuron changes to make it compatible with FS repo#4

Open
aws-mengchiy wants to merge 5 commits intomainfrom
new_clean_code
Open

Refactor neuron changes to make it compatible with FS repo#4
aws-mengchiy wants to merge 5 commits intomainfrom
new_clean_code

Conversation

@aws-mengchiy
Copy link
Collaborator

Delete unnecessary snippets and rewrite neuron changes to make it compatible with FS repo. It intends to minimize the neuron-specific changes we made as well.

@aws-mengchiy aws-mengchiy marked this pull request as ready for review March 27, 2024 17:18
ff_layer.linear1.output_partition_spec = (batch_axis_names, seq_axis_names, tp_axis_names)
ff_layer.linear2.output_partition_spec = (batch_axis_names, seq_axis_names, None)

if jax.default_backend() == "neuron":
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldnt this line be before we create the specs above?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, looks like it should. Maybe python compiler plugs in the method inline so that I did not hit any error. But will move it up.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants