Skip to content

Commit

Permalink
Update PJRT C API version to 0.60 (#160)
Browse files Browse the repository at this point in the history
Updated PJRT C API to version 0.60, with unimplemented features.

0.55
Adds **PJRT_Buffer_Type_F8E4M3** and **PJRT_Buffer_Type_F8E3M4**.
Default for **PJRT_Buffer_Type** - unsupported.
* * *
0.56 and 0.57
Adds **PJRT_Buffer_CopyRawToHost**.
Left unimplemented.
* * *
0.58
Nothing changed in **pjrt_c_api.h**.
* * *
0.59
Adds **PJRT_Extension_Type_MemoryDescriptions**.
No changes, we don't use extensions.
* * *
0.60
Adds **PJRT_Client_CreateBuffersForAsyncHostToDevice**,
**PJRT_AsyncHostToDeviceTransferManager_TransferData** and
**PJRT_AsyncHostToDeviceTransferManager_Destroy**.
Left unimplemented.
  • Loading branch information
sdjukicTT authored Jan 13, 2025
1 parent e5bd931 commit baf4a61
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 5 deletions.
4 changes: 4 additions & 0 deletions src/common/pjrt_implementation/stubs.inc
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
_STUB(PJRT_Buffer_Memory);
_STUB(PJRT_Buffer_Delete);
_STUB(PJRT_Buffer_IsDeleted);
_STUB(PJRT_Buffer_CopyRawToHost);
_STUB(PJRT_Buffer_CopyToDevice);
_STUB(PJRT_Buffer_ToHostBuffer);
_STUB(PJRT_Buffer_IsOnCpu);
Expand All @@ -101,6 +102,7 @@
_STUB(PJRT_Executable_OutputDimensions);
_STUB(PJRT_Buffer_CopyToMemory);
_STUB(PJRT_Client_CreateViewOfDeviceBuffer);
_STUB(PJRT_Client_CreateBuffersForAsyncHostToDevice);
_STUB(PJRT_Executable_Fingerprint);
_STUB(PJRT_Client_TopologyDescription);
_STUB(PJRT_Buffer_DynamicDimensionIndices);
Expand All @@ -109,3 +111,5 @@
_STUB(PJRT_Memory_Kind_Id);
_STUB(PJRT_ExecuteContext_Create);
_STUB(PJRT_ExecuteContext_Destroy);
_STUB(PJRT_AsyncHostToDeviceTransferManager_Destroy);
_STUB(PJRT_AsyncHostToDeviceTransferManager_TransferData);
90 changes: 85 additions & 5 deletions third_party/pjrt_c_api/xla/pjrt/c/pjrt_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ typedef enum {
PJRT_Extension_Type_Stream,
PJRT_Extension_Type_Layouts,
PJRT_Extension_Type_FFI,
PJRT_Extension_Type_MemoryDescriptions,
} PJRT_Extension_Type;

// PJRT_Extension_Base contains a type and a pointer to next
Expand Down Expand Up @@ -87,7 +88,7 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Extension_Base, next);
// Changes include:
// * Adding a new field to the PJRT_Api or argument structs
// * Renaming a method or argument (doesn't affect ABI)
#define PJRT_API_MINOR 54
#define PJRT_API_MINOR 60

// The plugin should set the major_version and minor_version of
// PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in
Expand Down Expand Up @@ -222,9 +223,9 @@ struct PJRT_Plugin_Attributes_Args {
};
PJRT_DEFINE_STRUCT_TRAITS(PJRT_Plugin_Attributes_Args, attributes);

// Returns an array of plugin attributes which are key-value pairs. One example
// attribute is the minimum supported StableHLO version.
// TODO(b/280349977): standardize the list of attributes.
// Returns an array of plugin attributes which are key-value pairs. Common keys
// include `xla_version`, `stablehlo_current_version`, and
// `stablehlo_minimum_version`.
typedef PJRT_Error *PJRT_Plugin_Attributes(PJRT_Plugin_Attributes_Args *args);

// ---------------------------------- Events -----------------------------------
Expand Down Expand Up @@ -315,11 +316,14 @@ typedef PJRT_Error *PJRT_Event_OnReady(PJRT_Event_OnReady_Args *args);
typedef struct PJRT_Client PJRT_Client;
typedef struct PJRT_Device PJRT_Device;
typedef struct PJRT_Memory PJRT_Memory;
typedef struct PJRT_ShapeSpec PJRT_ShapeSpec;
typedef struct PJRT_DeviceDescription PJRT_DeviceDescription;
typedef struct PJRT_TopologyDescription PJRT_TopologyDescription;
typedef struct PJRT_Executable PJRT_Executable;
typedef struct PJRT_LoadedExecutable PJRT_LoadedExecutable;
typedef struct PJRT_Buffer PJRT_Buffer;
typedef struct PJRT_AsyncHostToDeviceTransferManager
PJRT_AsyncHostToDeviceTransferManager;

// The caller of PJRT_Client_Create can optionally provide a key-value store
// accessible across nodes and/or processes. KV store access may be necessary to
Expand Down Expand Up @@ -600,6 +604,35 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_DefaultDeviceAssignment_Args,
typedef PJRT_Error *PJRT_Client_DefaultDeviceAssignment(
PJRT_Client_DefaultDeviceAssignment_Args *args);

struct PJRT_AsyncHostToDeviceTransferManager_Destroy_Args {
size_t struct_size;
PJRT_Extension_Base *extension_start;
PJRT_AsyncHostToDeviceTransferManager *transfer_manager;
};
PJRT_DEFINE_STRUCT_TRAITS(PJRT_AsyncHostToDeviceTransferManager_Destroy_Args,
transfer_manager);

// Frees `transfer_manager`. `transfer_manager` can be nullptr.
typedef PJRT_Error *PJRT_AsyncHostToDeviceTransferManager_Destroy(
PJRT_AsyncHostToDeviceTransferManager_Destroy_Args *args);

struct PJRT_AsyncHostToDeviceTransferManager_TransferData_Args {
size_t struct_size;
PJRT_Extension_Base *extension_start;
PJRT_AsyncHostToDeviceTransferManager *transfer_manager;
int buffer_index;
const void *data;
int64_t offset;
int64_t transfer_size;
bool is_last_transfer;
PJRT_Event *done_with_h2d_transfer; // out
};
PJRT_DEFINE_STRUCT_TRAITS(
PJRT_AsyncHostToDeviceTransferManager_TransferData_Args,
done_with_h2d_transfer);
typedef PJRT_Error *PJRT_AsyncHostToDeviceTransferManager_TransferData(
PJRT_AsyncHostToDeviceTransferManager_TransferData_Args *args);

typedef enum {
// Invalid primitive type to serve as default.
PJRT_Buffer_Type_INVALID,
Expand Down Expand Up @@ -652,6 +685,10 @@ typedef enum {
// 2-bit integer types
PJRT_Buffer_Type_S2,
PJRT_Buffer_Type_U2,

// More truncated 8 bit floating-point formats.
PJRT_Buffer_Type_F8E4M3,
PJRT_Buffer_Type_F8E3M4,
} PJRT_Buffer_Type;

typedef enum {
Expand Down Expand Up @@ -823,6 +860,31 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_CreateViewOfDeviceBuffer_Args, buffer);
typedef PJRT_Error *PJRT_Client_CreateViewOfDeviceBuffer(
PJRT_Client_CreateViewOfDeviceBuffer_Args *args);

struct PJRT_ShapeSpec {
size_t struct_size;
PJRT_Extension_Base *extension_start;
const int64_t *dims;
size_t num_dims;
PJRT_Buffer_Type element_type;
};
PJRT_DEFINE_STRUCT_TRAITS(PJRT_ShapeSpec, element_type);

struct PJRT_Client_CreateBuffersForAsyncHostToDevice_Args {
size_t struct_size;
PJRT_Extension_Base *extension_start;
PJRT_Client *client;
PJRT_ShapeSpec *shape_specs;
size_t num_shape_specs;
PJRT_Buffer_MemoryLayout **device_layouts; // optional
size_t num_device_layouts;
PJRT_Memory *memory;
PJRT_AsyncHostToDeviceTransferManager *transfer_manager; // out
};
PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_CreateBuffersForAsyncHostToDevice_Args,
transfer_manager);
typedef PJRT_Error *PJRT_Client_CreateBuffersForAsyncHostToDevice(
PJRT_Client_CreateBuffersForAsyncHostToDevice_Args *args);

// -------------------------- Device Descriptions ------------------------------

// Device descriptions may be associated with an actual device
Expand Down Expand Up @@ -1762,6 +1824,20 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_IsDeleted_Args, is_deleted);
// True if and only if PJRT_Buffer_Delete has previously been called.
typedef PJRT_Error *PJRT_Buffer_IsDeleted(PJRT_Buffer_IsDeleted_Args *args);

struct PJRT_Buffer_CopyRawToHost_Args {
size_t struct_size;
PJRT_Extension_Base *extension_start;
PJRT_Buffer *buffer;
void *dst;
int64_t offset;
int64_t transfer_size;
PJRT_Event *event; // out
};
PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_CopyRawToHost_Args, event);

typedef PJRT_Error *
PJRT_Buffer_CopyRawToHost(PJRT_Buffer_CopyRawToHost_Args *args);

struct PJRT_Buffer_CopyToDevice_Args {
size_t struct_size;
PJRT_Extension_Base *extension_start;
Expand Down Expand Up @@ -2253,11 +2329,15 @@ typedef struct PJRT_Api {

_PJRT_API_STRUCT_FIELD(PJRT_ExecuteContext_Create);
_PJRT_API_STRUCT_FIELD(PJRT_ExecuteContext_Destroy);
_PJRT_API_STRUCT_FIELD(PJRT_Buffer_CopyRawToHost);
_PJRT_API_STRUCT_FIELD(PJRT_AsyncHostToDeviceTransferManager_Destroy);
_PJRT_API_STRUCT_FIELD(PJRT_AsyncHostToDeviceTransferManager_TransferData);
_PJRT_API_STRUCT_FIELD(PJRT_Client_CreateBuffersForAsyncHostToDevice);
} PJRT_Api;

enum {
PJRT_Api_STRUCT_SIZE =
PJRT_STRUCT_SIZE(PJRT_Api, PJRT_Client_TopologyDescription)
PJRT_STRUCT_SIZE(PJRT_Api, PJRT_Client_CreateBuffersForAsyncHostToDevice)
};

#undef _PJRT_API_STRUCT_FIELD
Expand Down

0 comments on commit baf4a61

Please sign in to comment.