xgboost
Collective

Experimental support for exposing the internal communicator in XGBoost.

Typedefs

typedef void * TrackerHandle
 Handle to the tracker.
 

Functions

int XGTrackerCreate (char const *config, TrackerHandle *handle)
 Create a new tracker.
 
int XGTrackerWorkerArgs (TrackerHandle handle, char const **args)
 Get the arguments needed for running workers. This should be called after XGTrackerRun().
 
int XGTrackerRun (TrackerHandle handle, char const *config)
 Start the tracker. The tracker runs in the background and this function returns once the tracker is started.
 
int XGTrackerWaitFor (TrackerHandle handle, char const *config)
 Wait for the tracker to finish. This should be called after XGTrackerRun(). This function blocks until the tracker task is finished or the timeout is reached.
 
int XGTrackerFree (TrackerHandle handle)
 Free a tracker instance. This should be called after XGTrackerWaitFor(). If the tracker has not been properly waited for, this function shuts down all connections with the tracker, potentially leading to undefined behavior.
 
int XGCommunicatorInit (char const *config)
 Initialize the collective communicator.
 
int XGCommunicatorFinalize (void)
 Finalize the collective communicator.
 
int XGCommunicatorGetRank (void)
 Get rank of the current process.
 
int XGCommunicatorGetWorldSize (void)
 Get the total number of processes.
 
int XGCommunicatorIsDistributed (void)
 Check whether the communicator is distributed.
 
int XGCommunicatorPrint (char const *message)
 Print the message to the tracker.
 
int XGCommunicatorGetProcessorName (const char **name_str)
 Get the name of the processor.
 
int XGCommunicatorBroadcast (void *send_receive_buffer, size_t size, int root)
 Broadcast a memory region from the root to all other processes. This function is NOT thread-safe.
 
int XGCommunicatorAllreduce (void *send_receive_buffer, size_t count, int data_type, int op)
 Perform in-place allreduce. This function is NOT thread-safe.
 

Detailed Description

Experimental support for exposing the internal communicator in XGBoost.

Note
This is still under development.

The collective communicator in XGBoost evolved from the rabit project of dmlc but has changed significantly since its adoption. It consists of a tracker and a set of workers. The tracker is responsible for bootstrapping the communication group and handling centralized tasks like logging. The workers are actual communicators performing collective tasks like allreduce.

To use the collective implementation, first create a tracker with the corresponding parameters, then obtain the arguments for the workers using XGTrackerWorkerArgs(). These arguments can then be passed to XGCommunicatorInit(). Each call to XGCommunicatorInit() must be paired with a call to XGCommunicatorFinalize() for cleanup.

Please note that the communicator uses std::thread in C++, which has undefined behavior in a C++ destructor due to the runtime shutdown sequence. It is therefore preferable to call XGCommunicatorFinalize() before the runtime starts shutting down. This requirement is similar to a Python thread or socket, which should not be relied upon in a __del__ function.
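The ordering constraints stated in the function summaries (XGTrackerWorkerArgs() and XGTrackerWaitFor() after XGTrackerRun(), XGTrackerFree() after XGTrackerWaitFor()) can be encoded as a small state machine. This is a hypothetical, library-free model for illustration only: TrackerState, TrackerCall, and tracker_step are not part of XGBoost, and the model takes the strictest reading of the rules (for example, it only accepts worker-argument requests while the tracker is running).

```c
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>

/* Hypothetical model of the tracker lifecycle (not part of XGBoost).
 * It performs no networking; it only encodes the documented ordering:
 * XGTrackerCreate -> XGTrackerRun -> XGTrackerWaitFor -> XGTrackerFree,
 * with XGTrackerWorkerArgs allowed once the tracker is running. */
typedef enum {
  kTrackerNone,     /* nothing created yet */
  kTrackerCreated,  /* after XGTrackerCreate() */
  kTrackerRunning,  /* after XGTrackerRun() */
  kTrackerFinished, /* after XGTrackerWaitFor() */
  kTrackerFreed     /* after XGTrackerFree() */
} TrackerState;

typedef enum { kCallCreate, kCallRun, kCallWorkerArgs, kCallWaitFor, kCallFree } TrackerCall;

/* Returns true and advances *state if `call` is permitted in the current
 * state; returns false and leaves *state unchanged otherwise. */
static bool tracker_step(TrackerState *state, TrackerCall call) {
  assert(state != NULL);
  switch (call) {
    case kCallCreate:
      if (*state != kTrackerNone) return false;
      *state = kTrackerCreated;
      return true;
    case kCallRun:
      if (*state != kTrackerCreated) return false;
      *state = kTrackerRunning;
      return true;
    case kCallWorkerArgs:
      /* Strict reading: only while running; the state does not change. */
      return *state == kTrackerRunning;
    case kCallWaitFor:
      if (*state != kTrackerRunning) return false;
      *state = kTrackerFinished;
      return true;
    case kCallFree:
      /* Freeing before WaitFor may tear down live connections; reject it. */
      if (*state != kTrackerFinished) return false;
      *state = kTrackerFreed;
      return true;
  }
  return false;
}
```

In this model, freeing a tracker that is still running is rejected, which mirrors the warning in the XGTrackerFree() description about shutting down live connections.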

Since the communicator is used as a part of XGBoost, errors are returned when an XGBoost function is called; for instance, training a booster might return a connection error.
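Because every function on this page returns 0 for success and -1 for failure, call sites tend to repeat the same check. The helper below is a hypothetical sketch (xgb_ok is not an XGBoost function). In a real program you would typically also print the message from XGBGetLastError(), which is declared in xgboost/c_api.h; this version only reports the return code so that it runs without linking XGBoost.

```c
#include <assert.h>
#include <stdio.h>

/* Hypothetical helper: checks a return code that follows the
 * "0 for success, -1 for failure" convention used on this page.
 * Returns 1 on success, 0 on failure; `what` names the call in the message. */
static int xgb_ok(int rc, const char *what) {
  assert(what != NULL);
  if (rc == 0) {
    return 1;
  }
  /* A real program would also report XGBGetLastError() here. */
  fprintf(stderr, "%s failed with code %d\n", what, rc);
  return 0;
}
```

A call site might then read `if (!xgb_ok(XGCommunicatorInit(config), "XGCommunicatorInit")) { /* clean up and exit */ }`.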

Typedef Documentation

◆ TrackerHandle

typedef void* TrackerHandle

Handle to the tracker.

There are currently two types of tracker in XGBoost: rabit and federated. rabit is used for normal collective communication, while federated is used for federated learning.

Function Documentation

◆ XGCommunicatorAllreduce()

int XGCommunicatorAllreduce ( void *  send_receive_buffer,
size_t  count,
int  data_type,
int  op 
)

Perform in-place allreduce. This function is NOT thread-safe.

Example usage: the following code computes the element-wise sum of data across all workers.

enum class Op {
  kMax = 0, kMin = 1, kSum = 2, kBitwiseAND = 3, kBitwiseOR = 4, kBitwiseXOR = 5
};
std::vector<int> data(10);
// ...
Allreduce(data.data(), data.size(), DataType::kInt32, Op::kSum);
// ...
Parameters
send_receive_buffer: Buffer for both sending and receiving data.
count: Number of elements to be reduced.
data_type: Enumeration of data type, see xgboost::collective::DataType in communicator.h.
op: Enumeration of operation type, see xgboost::collective::Operation in communicator.h.
Returns
0 for success, -1 for failure.
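To make the in-place semantics concrete, the following single-process simulation reduces int32 buffers belonging to several simulated ranks and writes the result back into every buffer, so each simulated rank ends up holding the same reduced values. It is an illustrative sketch, not XGBoost's implementation: simulate_allreduce_int32 is a hypothetical name, and the operation codes simply mirror the Op enum from the example above.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* Operation codes mirroring the Op enum in the example above. */
enum { kOpMax = 0, kOpMin = 1, kOpSum = 2, kOpBitwiseAND = 3, kOpBitwiseOR = 4, kOpBitwiseXOR = 5 };

/* Hypothetical single-process simulation of an in-place allreduce.
 * bufs[r] is the buffer of simulated rank r and holds `count` int32 values.
 * On success every buffer holds the element-wise reduction and 0 is returned;
 * an unknown op returns -1 without touching any buffer. */
static int simulate_allreduce_int32(int32_t *bufs[], size_t n_ranks, size_t count, int op) {
  assert(bufs != NULL && n_ranks > 0);
  if (op < kOpMax || op > kOpBitwiseXOR) return -1;
  for (size_t i = 0; i < count; ++i) {
    int32_t acc = bufs[0][i];
    for (size_t r = 1; r < n_ranks; ++r) {
      int32_t v = bufs[r][i];
      switch (op) {
        case kOpMax: acc = v > acc ? v : acc; break;
        case kOpMin: acc = v < acc ? v : acc; break;
        case kOpSum: acc += v; break; /* overflow is not handled in this sketch */
        case kOpBitwiseAND: acc &= v; break;
        case kOpBitwiseOR: acc |= v; break;
        default: acc ^= v; break; /* kOpBitwiseXOR */
      }
    }
    /* Write the result back into every rank's buffer (in-place semantics). */
    for (size_t r = 0; r < n_ranks; ++r) bufs[r][i] = acc;
  }
  return 0;
}
```

With two simulated ranks holding {1, 2, 3} and {4, 5, 6}, a sum leaves both buffers at {5, 7, 9}.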

◆ XGCommunicatorBroadcast()

int XGCommunicatorBroadcast ( void *  send_receive_buffer,
size_t  size,
int  root 
)

Broadcast a memory region from the root to all other processes. This function is NOT thread-safe.

Example:

int a = 1;
Broadcast(&a, sizeof(a), root);
Parameters
send_receive_buffer: Pointer to the send or receive buffer.
size: Size of the data in bytes.
root: The process rank to broadcast from.
Returns
0 for success, -1 for failure.
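Note that size is measured in bytes here, while XGCommunicatorAllreduce() takes a count of elements. The hypothetical single-process simulation below copies the root's bytes into every other simulated rank's buffer; simulate_broadcast is not an XGBoost function, and a real broadcast moves the data over the network.

```c
#include <assert.h>
#include <stddef.h>
#include <string.h>

/* Hypothetical single-process simulation of a broadcast.
 * bufs[r] is the buffer of simulated rank r; `size` is in bytes, as in
 * XGCommunicatorBroadcast(). The root's bytes are copied to every other rank.
 * Returns 0 on success, -1 if root is not a valid rank. */
static int simulate_broadcast(void *bufs[], size_t n_ranks, size_t size, int root) {
  assert(bufs != NULL);
  if (root < 0 || (size_t)root >= n_ranks) return -1;
  for (size_t r = 0; r < n_ranks; ++r) {
    if (r != (size_t)root) memcpy(bufs[r], bufs[root], size);
  }
  return 0;
}
```

Passing sizeof(x) for an array x of two doubles copies all 16 bytes; passing the element count 2 instead would copy only 2 bytes.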

◆ XGCommunicatorFinalize()

int XGCommunicatorFinalize ( void  )

Finalize the collective communicator.

Call this function after you have finished all jobs.

Returns
0 for success, -1 for failure.

◆ XGCommunicatorGetProcessorName()

int XGCommunicatorGetProcessorName ( const char **  name_str)

Get the name of the processor.

Parameters
name_str: Pointer that receives the returned processor name.
Returns
0 for success, -1 for failure.

◆ XGCommunicatorGetRank()

int XGCommunicatorGetRank ( void  )

Get rank of the current process.

Returns
Rank of the worker.

◆ XGCommunicatorGetWorldSize()

int XGCommunicatorGetWorldSize ( void  )

Get the total number of processes.

Returns
Total world size.
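One common way a worker might use its rank and the world size is to select its own slice of a dataset. The helper below is a hypothetical sketch (shard_rows is not part of XGBoost) that splits n_rows into contiguous blocks, giving one extra row to each of the first n_rows % world_size ranks.

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical helper: computes the half-open row range [*begin, *end) owned
 * by `rank`, given values like those returned by XGCommunicatorGetRank() and
 * XGCommunicatorGetWorldSize(). */
static void shard_rows(size_t n_rows, int rank, int world_size,
                       size_t *begin, size_t *end) {
  assert(world_size > 0 && rank >= 0 && rank < world_size);
  assert(begin != NULL && end != NULL);
  size_t n = (size_t)world_size, r = (size_t)rank;
  size_t base = n_rows / n;   /* rows every rank receives */
  size_t extra = n_rows % n;  /* leftover rows for the first ranks */
  *begin = r * base + (r < extra ? r : extra);
  *end = *begin + base + (r < extra ? 1 : 0);
}
```

For 10 rows and a world size of 3, ranks 0, 1, and 2 receive [0, 4), [4, 7), and [7, 10).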

◆ XGCommunicatorInit()

int XGCommunicatorInit ( char const *  config)

Initialize the collective communicator.

The communicator API is currently experimental; function signatures may change in the future without notice.

Call this once in the worker process before using anything else. Make sure XGCommunicatorFinalize() is called after use. The initialized communicator is a global thread-local variable.

Parameters
config: JSON-encoded configuration. Accepted JSON keys are:
  • dmlc_communicator: The type of the communicator, this should match the tracker type.
    • rabit: Use Rabit. This is the default if the type is unspecified.
    • federated: Use the gRPC interface for Federated Learning.

Only applicable to the rabit communicator:

  • dmlc_tracker_uri: Hostname or IP address of the tracker.
  • dmlc_tracker_port: Port number of the tracker.
  • dmlc_task_id: ID of the current task, which can be used to obtain a deterministic rank assignment.
  • dmlc_retry: The number of retries for connection failure.
  • dmlc_timeout: Timeout in seconds.
  • dmlc_nccl_path: Path to the NCCL shared library libnccl.so.

Only applicable to the federated communicator (use upper case for environment variables, use lower case for runtime configuration):

  • federated_server_address: Address of the federated server.
  • federated_world_size: Number of federated workers.
  • federated_rank: Rank of the current worker.
  • federated_server_cert_path: Server certificate file path. Only needed for the SSL mode.
  • federated_client_key_path: Client key file path. Only needed for the SSL mode.
  • federated_client_cert_path: Client certificate file path. Only needed for the SSL mode.
Returns
0 for success, -1 for failure.
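In the usual workflow the JSON passed here comes from XGTrackerWorkerArgs(), but a configuration can also be built by hand from the rabit keys listed above. The sketch below assumes placeholder values for the host, port, and task ID, and passes dmlc_task_id as a JSON string, which is an assumption about the expected type. make_rabit_worker_config is a hypothetical helper; it does no JSON escaping and therefore rejects strings containing quotes or backslashes.

```c
#include <assert.h>
#include <stdio.h>
#include <string.h>

/* Hypothetical helper: writes a rabit worker config for XGCommunicatorInit().
 * No JSON escaping is done, so strings containing quotes or backslashes
 * are rejected. Returns the length written, or -1 on invalid input or a
 * buffer that is too small. */
static int make_rabit_worker_config(char *buf, size_t len, const char *host,
                                    int port, const char *task_id) {
  assert(buf != NULL && len > 0);
  if (host == NULL || task_id == NULL) return -1;
  if (strpbrk(host, "\"\\") != NULL || strpbrk(task_id, "\"\\") != NULL) return -1;
  int n = snprintf(buf, len,
                   "{\"dmlc_communicator\": \"rabit\", "
                   "\"dmlc_tracker_uri\": \"%s\", "
                   "\"dmlc_tracker_port\": %d, "
                   "\"dmlc_task_id\": \"%s\"}",
                   host, port, task_id);
  return (n < 0 || (size_t)n >= len) ? -1 : n;
}
```

The resulting string would then be passed as the config argument of XGCommunicatorInit().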

◆ XGCommunicatorIsDistributed()

int XGCommunicatorIsDistributed ( void  )

Check whether the communicator is distributed.

Returns
True if the communicator is distributed.

◆ XGCommunicatorPrint()

int XGCommunicatorPrint ( char const *  message)

Print the message to the tracker.

This function can be used to report progress information to the user who monitors the tracker.

Parameters
message: The message to be printed.
Returns
0 for success, -1 for failure.
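Since every worker can print to the same tracker, it may help to tag each message with the sender's rank. The following is a hypothetical sketch (format_progress is not an XGBoost function) that formats such a message, for example using the value returned by XGCommunicatorGetRank(), before it is passed to XGCommunicatorPrint().

```c
#include <assert.h>
#include <stdio.h>
#include <string.h>

/* Hypothetical helper: formats "[rank R] iteration I/N" into buf.
 * Returns 0 on success, -1 if the message would be truncated. */
static int format_progress(char *buf, size_t len, int rank, int iter, int n_iters) {
  assert(buf != NULL && len > 0);
  int n = snprintf(buf, len, "[rank %d] iteration %d/%d", rank, iter, n_iters);
  return (n < 0 || (size_t)n >= len) ? -1 : 0;
}
```

After a successful call, XGCommunicatorPrint(buf) would send the formatted text to the tracker.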

◆ XGTrackerCreate()

int XGTrackerCreate ( char const *  config,
TrackerHandle *  handle 
)

Create a new tracker.

Parameters
config: JSON-encoded parameters.
  • dmlc_communicator: String, the type of tracker to create. Available options are rabit and federated. See TrackerHandle for more info.
  • n_workers: Integer, the number of workers.
  • port: (Optional) Integer, the port this tracker should listen on.
  • timeout: (Optional) Integer, timeout in seconds for various networking operations. Default is 300 seconds.

Some configurations are rabit specific:

  • host: (Optional) String, used by the rabit tracker to specify the address of the host. This can be useful when the communicator cannot reliably obtain the host address.
  • sortby: (Optional) Integer.
    • 0: Sort workers by their host name.
    • 1: Sort workers by task IDs.

Some federated specific configurations:

  • federated_secure: Boolean, whether this is a secure server. Set to false for testing.
  • server_key_path: Path to the server key. Used only if this is a secure server.
  • server_cert_path: Path to the server certificate. Used only if this is a secure server.
  • client_cert_path: Path to the client certificate. Used only if this is a secure server.
Parameters
handle: The handle to the created tracker.
Returns
0 for success, -1 for failure.
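A configuration for XGTrackerCreate() can be assembled from the keys documented above. This hypothetical helper (make_tracker_config is not part of XGBoost) writes only the common keys, rejects tracker types other than the two documented options, and assumes at least one worker; the port and timeout values in the usage example are placeholders.

```c
#include <assert.h>
#include <stdio.h>
#include <string.h>

/* Hypothetical helper: writes the common XGTrackerCreate() keys into buf.
 * Only the two documented tracker types are accepted.
 * Returns the length written, or -1 on invalid input or a too-small buffer. */
static int make_tracker_config(char *buf, size_t len, const char *communicator,
                               int n_workers, int port, int timeout) {
  assert(buf != NULL && len > 0);
  if (communicator == NULL ||
      (strcmp(communicator, "rabit") != 0 && strcmp(communicator, "federated") != 0)) {
    return -1;
  }
  if (n_workers <= 0) return -1; /* assumption: at least one worker */
  int n = snprintf(buf, len,
                   "{\"dmlc_communicator\": \"%s\", \"n_workers\": %d, "
                   "\"port\": %d, \"timeout\": %d}",
                   communicator, n_workers, port, timeout);
  return (n < 0 || (size_t)n >= len) ? -1 : n;
}
```

For instance, make_tracker_config(buf, sizeof(buf), "rabit", 4, 9091, 300) writes {"dmlc_communicator": "rabit", "n_workers": 4, "port": 9091, "timeout": 300}, which could then be passed as the config argument.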

◆ XGTrackerFree()

int XGTrackerFree ( TrackerHandle  handle)

Free a tracker instance. This should be called after XGTrackerWaitFor(). If the tracker has not been properly waited for, this function shuts down all connections with the tracker, potentially leading to undefined behavior.

Parameters
handle: The handle to the tracker.
Returns
0 for success, -1 for failure.

◆ XGTrackerRun()

int XGTrackerRun ( TrackerHandle  handle,
char const *  config 
)

Start the tracker. The tracker runs in the background and this function returns once the tracker is started.

Parameters
handle: The handle to the tracker.
config: Unused at the moment; reserved for future use.
Returns
0 for success, -1 for failure.

◆ XGTrackerWaitFor()

int XGTrackerWaitFor ( TrackerHandle  handle,
char const *  config 
)

Wait for the tracker to finish. This should be called after XGTrackerRun(). This function blocks until the tracker task is finished or the timeout is reached.

Parameters
handle: The handle to the tracker.
config: JSON-encoded configuration. No arguments are required yet; reserved for future use.
Returns
0 for success, -1 for failure.

◆ XGTrackerWorkerArgs()

int XGTrackerWorkerArgs ( TrackerHandle  handle,
char const **  args 
)

Get the arguments needed for running workers. This should be called after XGTrackerRun().

Parameters
handle: The handle to the tracker.
args: The arguments, returned as a JSON document.
Returns
0 for success, -1 for failure.
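This page does not state how long the pointer returned through args remains valid, so a caller that needs the worker arguments later (for example, to hand them to worker processes) may prefer to keep its own copy. A minimal defensive sketch; copy_worker_args is a hypothetical helper, and the caller releases the copy with free().

```c
#include <assert.h>
#include <stdlib.h>
#include <string.h>

/* Hypothetical helper: returns a heap copy of the JSON text returned through
 * `args`, or NULL if args is NULL or allocation fails. Release with free(). */
static char *copy_worker_args(const char *args) {
  if (args == NULL) return NULL;
  size_t n = strlen(args) + 1; /* include the terminating NUL */
  char *copy = malloc(n);
  if (copy != NULL) memcpy(copy, args, n);
  return copy;
}
```

The copy can then be passed to XGCommunicatorInit() in each worker, independently of the tracker's own storage.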