[WIP] Strongly Connected Components (Find trivial SCCs in components) #5448
ngokulakrish wants to merge 24 commits into rapidsai:main
…src|dst_property_t object
…nh_prim_factory
And #5442 updates only 11 files. I see updates in 22 files in this PR. Did you make any updates outside strongly_connected_components_impl.cuh?
  bwd_only_vertices.resize(count, handle.get_stream());
}

// 4. FWD_OR_BWD = FWD ∪ BWD (needed to compute REMAINDER)
You may merge steps 4 and 5. What we need to compute is UC - (FWD ∪ BWD); the union itself does not need to be stored.
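The suggestion can be illustrated with a small host-side sketch. It is not the code from this PR: the helper name `compute_remainder` is hypothetical, the inputs are assumed to be sorted, duplicate-free vertex lists, and standard library algorithms stand in for whatever device primitives the implementation actually uses. The point is only that UC - (FWD ∪ BWD) can be produced in one pass without first building FWD ∪ BWD.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iterator>
#include <vector>

// Hypothetical helper (not from the PR): returns UC - (FWD ∪ BWD) in a single
// pass over UC, without materializing the union FWD ∪ BWD.
// Assumption: fwd and bwd are sorted so that binary search is valid.
std::vector<int32_t> compute_remainder(std::vector<int32_t> const& uc,
                                       std::vector<int32_t> const& fwd,
                                       std::vector<int32_t> const& bwd)
{
  assert(std::is_sorted(fwd.begin(), fwd.end()));
  assert(std::is_sorted(bwd.begin(), bwd.end()));

  std::vector<int32_t> remainder;
  remainder.reserve(uc.size());
  // Keep a vertex only if it is in neither the forward nor the backward set.
  std::copy_if(uc.begin(), uc.end(), std::back_inserter(remainder), [&](int32_t v) {
    return !std::binary_search(fwd.begin(), fwd.end(), v) &&
           !std::binary_search(bwd.begin(), bwd.end(), v);
  });
  return remainder;
}
```

For example, with UC = {0, 1, 2, 3, 4, 5}, FWD = {1, 2, 3} and BWD = {3, 4}, the remainder is {0, 5}.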
  component_local_min_vertex_ids.size(),
  raft::comms::op_t::MIN,
  handle.get_stream());
}
Two things to consider here.
- FWD only, BWD only, and the remaining vertices outside FWD U BWD each become a next unresolved component if their size is larger than 1. If the size is 1, they are trivial SCCs as well. Some components may also have size 0; these should be discarded. In multi-GPU, we need to call device_allreduce to find the total number of vertices in each next unresolved component candidate.
- For the next iteration, we need to update the new (unresolved_component_offsets, unresolved_component_vertices) pair. In multi-GPU, unresolved_component_offsets.size() = global number of unresolved components + 1, and the vertices in each GPU should be placed based on this offset array.
I think this function is the best place to achieve this.
We have an array of size (# old unresolved components) * 4 (or * 3 if we disregard SCCs). We can create an array of the same size and set each value to 1 if the global size is > 1 (i.e. a new unresolved component) and to 0 if the global size is <= 1 (i.e. an SCC or empty).
Then we run thrust::exclusive_scan. The scan result maps each position in the above array to the new unresolved component index (which is used to access the new unresolved_component_offsets).
At the end, we can just resize the return value of this function (to exclude the SCCs) and start the next iteration.
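A minimal host-side sketch of the flag-and-scan step described above, under stated assumptions: `std::exclusive_scan` (C++17) stands in for `thrust::exclusive_scan`, the input is assumed to already hold the global (all-reduced) size of each candidate, and the names `candidate_mapping_t` and `map_candidates` are hypothetical and do not come from this PR.

```cpp
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Hypothetical result type (not from the PR).
struct candidate_mapping_t {
  std::vector<std::size_t> flags;      // 1 if the candidate is a new unresolved component
  std::vector<std::size_t> new_index;  // meaningful only where flags[i] == 1
  std::size_t num_new_components{0};
};

// global_sizes[i] is assumed to be the global size of candidate i
// (in multi-GPU, the value obtained after the all-reduce).
candidate_mapping_t map_candidates(std::vector<std::size_t> const& global_sizes)
{
  candidate_mapping_t m;
  m.flags.resize(global_sizes.size());
  m.new_index.resize(global_sizes.size());

  // Size > 1: new unresolved component. Size <= 1: trivial SCC or empty.
  for (std::size_t i = 0; i < global_sizes.size(); ++i) {
    m.flags[i] = (global_sizes[i] > 1) ? 1 : 0;
  }

  // Exclusive prefix sum: position i receives the number of flagged
  // candidates before it, which is its index among the new components.
  std::exclusive_scan(m.flags.begin(), m.flags.end(), m.new_index.begin(), std::size_t{0});

  if (!global_sizes.empty()) {
    m.num_new_components = m.new_index.back() + m.flags.back();
  }
  return m;
}
```

With global sizes {3, 1, 0, 5, 2, 1}, the flags are {1, 0, 0, 1, 1, 0} and the scan gives {0, 1, 1, 1, 2, 3}, so candidates 0, 3 and 4 become new unresolved components 0, 1 and 2. A new offsets array would then have `num_new_components + 1` entries.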
This is an optimization PR for SCC #5442 to recursively trim trivial SCCs in components before running the FW-BW pass.
Should be reviewed/merged after #5442.