mrcslws | 2 years ago

Aha, I was hoping to learn about something like this. Thanks for sharing! I'll try it some time. PyTorch does use different threads for the forward and backward passes, so, as you suggest, setting that flag might only improve the forward pass.

gregjm | 2 years ago

The CUDA Runtime and Driver APIs keep per-thread state, so work issued from other threads would unfortunately bypass our trick here for setting the flag. Assuming you're on Linux, I'd suggest creating a shared library that intercepts calls to the Driver API; since all Runtime functions are implemented as wrappers around Driver functions, that covers both. You'd have to intercept every call that creates a context or sets its flags:

  * `cuCtxCreate`

  * `cuCtxCreate_v3`

  * `cuCtxSetFlags`

  * `cuDevicePrimaryCtxRetain`

  * `cuDevicePrimaryCtxSetFlags`
...and make sure that the three least significant bits of any `flags` argument (the bits covered by `CU_CTX_SCHED_MASK`) are set to `CU_CTX_SCHED_BLOCKING_SYNC`.
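A minimal sketch of such a shim in C, assuming glibc's `dlsym(RTLD_NEXT, ...)` and a launch via `LD_PRELOAD`. So that it compiles without the CUDA toolkit, it declares its own stand-ins for `CUresult`, `CUdevice`, `CUcontext` and the two flag constants (the constant values match `cuda.h`; the typedefs are an ABI-compatible assumption). It also assumes the application reaches the driver through the exported `_v2` names that `cuda.h` maps `cuCtxCreate` and `cuDevicePrimaryCtxSetFlags` to in recent toolkits; check `nm -D libcuda.so` for your version. `cuDevicePrimaryCtxRetain` takes no flags, so the sketch reads the current ones with `cuDevicePrimaryCtxGetState` and rewrites them before retaining. Callers that look up driver functions themselves (a `dlsym` on a `dlopen` handle, or `cuGetProcAddress`) are not intercepted this way, so treat it as a starting point rather than a drop-in fix.

```c
#define _GNU_SOURCE /* needed for RTLD_NEXT */
#include <dlfcn.h>
#include <stdio.h>
#include <stdlib.h>

/* Stand-ins for cuda.h (assumption: ABI-compatible with the real types). */
typedef int CUresult; /* an int-sized enum in cuda.h */
typedef int CUdevice;
typedef struct CUctx_st *CUcontext;
#define CUDA_SUCCESS 0
#define CU_CTX_SCHED_MASK 0x07u
#define CU_CTX_SCHED_BLOCKING_SYNC 0x04u

typedef CUresult (*ctx_create_v2_fn)(CUcontext *, unsigned int, CUdevice);
/* The real second parameter is CUexecAffinityParam *; void * avoids the header. */
typedef CUresult (*ctx_create_v3_fn)(CUcontext *, void *, int, unsigned int, CUdevice);
typedef CUresult (*ctx_set_flags_fn)(unsigned int);
typedef CUresult (*primary_set_flags_fn)(CUdevice, unsigned int);
typedef CUresult (*primary_get_state_fn)(CUdevice, unsigned int *, int *);
typedef CUresult (*primary_retain_fn)(CUcontext *, CUdevice);

/* Replace the scheduling bits (the low three) with BLOCKING_SYNC; keep the rest. */
static unsigned int force_blocking_sync(unsigned int flags) {
    return (flags & ~CU_CTX_SCHED_MASK) | CU_CTX_SCHED_BLOCKING_SYNC;
}

/* Find the next definition of `name` after this library, i.e. libcuda's.
 * Context creation is rare, so we look it up on every call instead of caching,
 * which avoids a lazy-initialization race between threads. */
static void *real_symbol(const char *name) {
    void *sym = dlsym(RTLD_NEXT, name);
    if (sym == NULL) {
        fprintf(stderr, "cuda shim: cannot resolve %s: %s\n", name, dlerror());
        abort();
    }
    return sym;
}

CUresult cuCtxCreate_v2(CUcontext *pctx, unsigned int flags, CUdevice dev) {
    ctx_create_v2_fn real = (ctx_create_v2_fn)real_symbol("cuCtxCreate_v2");
    return real(pctx, force_blocking_sync(flags), dev);
}

CUresult cuCtxCreate_v3(CUcontext *pctx, void *params, int num_params,
                        unsigned int flags, CUdevice dev) {
    ctx_create_v3_fn real = (ctx_create_v3_fn)real_symbol("cuCtxCreate_v3");
    return real(pctx, params, num_params, force_blocking_sync(flags), dev);
}

CUresult cuCtxSetFlags(unsigned int flags) {
    ctx_set_flags_fn real = (ctx_set_flags_fn)real_symbol("cuCtxSetFlags");
    return real(force_blocking_sync(flags));
}

CUresult cuDevicePrimaryCtxSetFlags_v2(CUdevice dev, unsigned int flags) {
    primary_set_flags_fn real =
        (primary_set_flags_fn)real_symbol("cuDevicePrimaryCtxSetFlags_v2");
    return real(dev, force_blocking_sync(flags));
}

CUresult cuDevicePrimaryCtxRetain(CUcontext *pctx, CUdevice dev) {
    primary_get_state_fn get_state =
        (primary_get_state_fn)real_symbol("cuDevicePrimaryCtxGetState");
    unsigned int flags = 0;
    int active = 0;
    /* Best effort: keep the primary context's other flags and force the
     * scheduling bits (via our wrapper above) before retaining. Errors here
     * are ignored so the retain itself still goes through. */
    if (get_state(dev, &flags, &active) == CUDA_SUCCESS) {
        (void)cuDevicePrimaryCtxSetFlags_v2(dev, flags);
    }
    primary_retain_fn real = (primary_retain_fn)real_symbol("cuDevicePrimaryCtxRetain");
    return real(pctx, dev);
}
```

It would be built with something like `gcc -shared -fPIC -o libcudashim.so cuda_sched_shim.c -ldl` and run as `LD_PRELOAD=./libcudashim.so python train.py` (file names hypothetical). Masking rather than overwriting `flags` keeps unrelated bits such as `CU_CTX_MAP_HOST` intact.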

cuDevicePrimaryCtxSetFlags: https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__PR...

dlsym(3): https://man.archlinux.org/man/dlsym.3.en

ld.so(8): https://man.archlinux.org/man/ld.so.8.en