xgboost
Experimental support for exposing the internal communicator in XGBoost.
Typedefs

| Type | Name | Description |
|---|---|---|
| `typedef void *` | `TrackerHandle` | Handle to the tracker. |

Functions

| Return | Function | Description |
|---|---|---|
| `int` | `XGTrackerCreate(char const *config, TrackerHandle *handle)` | Create a new tracker. |
| `int` | `XGTrackerWorkerArgs(TrackerHandle handle, char const **args)` | Get the arguments needed for running workers. This should be called after XGTrackerRun(). |
| `int` | `XGTrackerRun(TrackerHandle handle, char const *config)` | Start the tracker. The tracker runs in the background, and this function returns once the tracker has started. |
| `int` | `XGTrackerWaitFor(TrackerHandle handle, char const *config)` | Wait for the tracker to finish. This should be called after XGTrackerRun(). The function blocks until the tracker task is finished or the timeout is reached. |
| `int` | `XGTrackerFree(TrackerHandle handle)` | Free a tracker instance. This should be called after XGTrackerWaitFor(). If the tracker has not been properly waited for, this function shuts down all connections with the tracker, potentially leading to undefined behavior. |
| `int` | `XGCommunicatorInit(char const *config)` | Initialize the collective communicator. |
| `int` | `XGCommunicatorFinalize(void)` | Finalize the collective communicator. |
| `int` | `XGCommunicatorGetRank(void)` | Get the rank of the current process. |
| `int` | `XGCommunicatorGetWorldSize(void)` | Get the total number of processes. |
| `int` | `XGCommunicatorIsDistributed(void)` | Check whether the communicator is distributed. |
| `int` | `XGCommunicatorPrint(char const *message)` | Print the message to the tracker. |
| `int` | `XGCommunicatorGetProcessorName(const char **name_str)` | Get the name of the processor. |
| `int` | `XGCommunicatorBroadcast(void *send_receive_buffer, size_t size, int root)` | Broadcast a memory region from root to all other processes. This function is NOT thread-safe. |
| `int` | `XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int data_type, int op)` | Perform an in-place allreduce. This function is NOT thread-safe. |
The collective communicator in XGBoost evolved from the `rabit` project of dmlc but has changed significantly since its adoption. It consists of a tracker and a set of workers. The tracker is responsible for bootstrapping the communication group and handling centralized tasks like logging. The workers are the actual communicators, performing collective tasks such as allreduce.
To use the collective implementation, first create a tracker with the corresponding parameters and start it with XGTrackerRun(), then get the arguments for the workers using XGTrackerWorkerArgs(). The obtained arguments can then be passed to XGCommunicatorInit(). Every call to XGCommunicatorInit() must be paired with a call to XGCommunicatorFinalize() for cleanup. Note that the communicator uses `std::thread` in C++, which has undefined behavior in a C++ destructor due to the runtime shutdown sequence, so call XGCommunicatorFinalize() before the runtime begins shutting down. This requirement is similar to a Python thread or socket, which should not be relied upon in a `__del__` function.
Since the communicator is used as part of XGBoost, its errors are returned by whichever XGBoost function triggers them; for instance, training a booster might return a connection error.
typedef void* TrackerHandle
Handle to the tracker.
There are currently two types of tracker in XGBoost: `rabit` and `federated`. `rabit` is used for normal collective communication, while `federated` is used for federated learning.
int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int data_type, int op)
Perform an in-place allreduce. This function is NOT thread-safe.
Example usage: the following code computes the sum of the results.
send_receive_buffer | Buffer for both sending and receiving data. |
count | Number of elements to be reduced. |
data_type | Enumeration of data type, see xgboost::collective::DataType in communicator.h. |
op | Enumeration of operation type, see xgboost::collective::Operation in communicator.h. |
int XGCommunicatorBroadcast(void *send_receive_buffer, size_t size, int root)
Broadcast a memory region from root to all other processes. This function is NOT thread-safe.
Example:
send_receive_buffer | Pointer to the send or receive buffer. |
size | Size of the data in bytes. |
root | The process rank to broadcast from. |
int XGCommunicatorFinalize(void)
Finalize the collective communicator.
Call this function after you have finished all jobs.
int XGCommunicatorGetProcessorName(const char **name_str)
Get the name of the processor.
name_str | Pointer that receives the returned processor name. |
int XGCommunicatorGetRank(void)
Get the rank of the current process.
int XGCommunicatorGetWorldSize(void)
Get the total number of processes.
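Rank and world size are typically used to split work between processes. The helper below is plain C and a hypothetical example, not part of XGBoost: it assigns each rank a contiguous, near-equal range of rows.

```c
#include <stddef.h>

/* Split n rows across world_size ranks (world_size must be positive) and
 * return the half-open range [*begin, *end) owned by `rank`. The first
 * n % world_size ranks receive one extra row each. */
void shard_range(size_t n, int rank, int world_size, size_t *begin, size_t *end) {
  size_t w = (size_t)world_size;
  size_t r = (size_t)rank;
  size_t base = n / w;
  size_t extra = n % w;
  *begin = r * base + (r < extra ? r : extra);
  *end = *begin + base + (r < extra ? 1 : 0);
}
```

For example, a worker would call shard_range(n_rows, XGCommunicatorGetRank(), XGCommunicatorGetWorldSize(), &begin, &end) and load only rows begin to end - 1.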
int XGCommunicatorInit(char const *config)
Initialize the collective communicator.
The communicator API is currently experimental; function signatures may change in the future without notice.
Call this once in the worker process before using anything else, and make sure XGCommunicatorFinalize() is called after use. The initialized communicator is a global thread-local variable.
config | JSON encoded configuration. Accepted JSON keys are:
Only applicable to the `rabit` communicator:
`libnccl.so`.
Only applicable to the `federated` communicator (use upper case for environment variables, use lower case for runtime configuration):
int XGCommunicatorIsDistributed(void)
Check whether the communicator is distributed.
int XGCommunicatorPrint(char const *message)
Print the message to the tracker.
This function can be used to report progress to the user who monitors the tracker.
message | The message to be printed. |
int XGTrackerCreate(char const *config, TrackerHandle *handle)
Create a new tracker.
config | JSON encoded parameters. Available tracker types are `rabit` and `federated`; see TrackerHandle for more info.
Some configurations are `rabit` specific:
One option is used by the `rabit` tracker to specify the address of the host. This can be useful when the communicator cannot reliably obtain the host address.
Some `federated` specific configurations:
handle | The handle to the created tracker. |
int XGTrackerFree(TrackerHandle handle)
Free a tracker instance. This should be called after XGTrackerWaitFor(). If the tracker has not been properly waited for, this function shuts down all connections with the tracker, potentially leading to undefined behavior.
handle | The handle to the tracker. |
int XGTrackerRun(TrackerHandle handle, char const *config)
Start the tracker. The tracker runs in the background, and this function returns once the tracker has started.
handle | The handle to the tracker. |
config | Unused at the moment; reserved for future use. |
int XGTrackerWaitFor(TrackerHandle handle, char const *config)
Wait for the tracker to finish. This should be called after XGTrackerRun(). The function blocks until the tracker task is finished or the timeout is reached.
handle | The handle to the tracker. |
config | JSON encoded configuration. No arguments are required yet; reserved for future use. |
int XGTrackerWorkerArgs(TrackerHandle handle, char const **args)
Get the arguments needed for running workers. This should be called after XGTrackerRun().
handle | The handle to the tracker. |
args | The arguments returned as a JSON document. |