How to Obtain Runtime Data
OneFlow supports oneflow.watch and oneflow.watch_diff. We can use them to register callback functions that receive data tensors and gradient tensors from job functions at runtime.
Usage Guide
To get a data tensor or gradient tensor from a job function, follow these steps:
- Write a callback function and annotate its parameter to indicate the data type. The logic of the callback function is up to the user.
- When defining the job function, use oneflow.watch or oneflow.watch_diff to register the callback function. The former delivers a data tensor; the latter delivers the corresponding gradient.
- At the appropriate time while the job function is running, OneFlow calls the registered callback function, passes it the monitored data, and executes the logic in the callback function.
Take oneflow.watch as an example:
    def my_watch(x: T):
        # process x
        ...

    @global_function()
    def foo() -> T:
        # define network ...
        oneflow.watch(x, my_watch)
        # ...
The T in the code above is a data type from oneflow.typing, such as oneflow.typing.Numpy. Please refer to this article.
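The three steps above can be illustrated without OneFlow at all. The following is a minimal pure-Python sketch of the register-then-call-back pattern, not OneFlow's implementation: the `_watchers` registry, the `watch` helper, and `run_job` are hypothetical names invented for this illustration, and plain lists stand in for tensors.

```python
from typing import Callable, Dict, List

# Hypothetical stand-in for the framework's bookkeeping; not OneFlow code.
_watchers: Dict[str, List[Callable[[List[float]], None]]] = {}


def watch(name: str, callback: Callable[[List[float]], None]) -> None:
    """Step 2: register a callback for the tensor called `name`."""
    _watchers.setdefault(name, []).append(callback)


def run_job(values: List[float]) -> List[float]:
    """Step 3: compute the tensor, then hand it to every registered callback."""
    x = [v * 2.0 for v in values]  # toy "network": double each element
    for callback in _watchers.get("x", []):
        callback(x)
    return x


seen: List[List[float]] = []


def my_watch(x: List[float]) -> None:
    """Step 1: the user-defined callback; here it only records what it gets."""
    seen.append(x)


watch("x", my_watch)
result = run_job([1.0, 2.0, 3.0])
print(seen)  # the callback received the same data the job produced
```

The point of the pattern is that the user never pulls the value out of the job; the framework pushes it to the callback once the value exists.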
We will use the following examples to demonstrate how to use watch and watch_diff.
Use watch to Obtain Data at Runtime
The following example demonstrates how to use oneflow.watch to obtain data from an intermediate layer in OneFlow.
Code: test_watch.py
Run the above code:
python3 test_watch.py
We get results like the following:
in: [ 0.15727027 0.45887455 0.10939325 0.66666406 -0.62354755]
out: [0.15727027 0.45887455 0.10939325 0.66666406 0. ]
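The relationship between the two printed lines can be checked by hand. The sketch below assumes, based only on the job's name, that ReluJob applies a standard element-wise ReLU; the `relu` helper is a plain-Python stand-in written for this illustration, not a OneFlow call.

```python
def relu(values):
    """Element-wise ReLU: keep positive entries, replace the rest with 0."""
    return [v if v > 0 else 0.0 for v in values]


# Input values copied from the sample output above.
data_in = [0.15727027, 0.45887455, 0.10939325, 0.66666406, -0.62354755]
data_out = relu(data_in)
print("in: ", data_in)
print("out:", data_out)
```

Only the last element is negative, so it is the only one that changes, which matches the sample output.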
Code Explanation
In the example, we focus on y in ReluJob, so we call flow.watch(y, watch_handler) to monitor y. The function oneflow.watch takes two parameters:
- The first parameter is y, the tensor we want to monitor.
- The second parameter is a callback function. When OneFlow uses device resources to execute ReluJob, it passes y as an argument to the callback function. We define our callback function watch_handler to print its argument.
Users can write custom callback functions to process the data from OneFlow according to their own requirements.
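As one possible example of custom processing, a callback can store a compact summary of each tensor instead of printing it. This is a sketch under assumptions: `summarize_handler` and `history` are hypothetical names, plain lists stand in for the arrays OneFlow would deliver, and in a real job function the parameter would carry the oneflow.typing annotation described earlier.

```python
history = []


def summarize_handler(y):
    """Record min, max, and mean of the watched values instead of printing them."""
    history.append({"min": min(y), "max": max(y), "mean": sum(y) / len(y)})


# Simulate the framework invoking the callback once per run of the job.
summarize_handler([0.0, 1.0, 2.0])
summarize_handler([4.0, 4.0, 4.0])
print(history)
```

Keeping a summary rather than the full tensor keeps memory use small when the job function runs many times.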
Use watch_diff to Obtain Gradients at Runtime
The following example demonstrates how to use oneflow.watch_diff to obtain the gradient at runtime.
Code: test_watch_diff.py
Run the above code:
python3 test_watch_diff.py
[ ...
[ 1.39966095e-03 3.49164731e-03 3.31605263e-02 4.50417027e-03
7.73609674e-04 4.89911772e-02 2.47627571e-02 7.65468649e-05
-1.18361652e-01 1.20161276e-03]] (100, 10) float32
Code Explanation
In the example above, we use oneflow.watch_diff to obtain the gradient. The process is the same as in the example that uses oneflow.watch to obtain a data tensor.
First, we define the callback function:
    def watch_diff_handler(blob: tp.Numpy):
        print("watch_diff_handler:", blob, blob.shape, blob.dtype)
Then we use oneflow.watch_diff to register the callback function in the job function:
flow.watch_diff(logits, watch_diff_handler)
At runtime, the OneFlow framework calls watch_diff_handler and passes it the gradient corresponding to logits.
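To make "the gradient corresponding to logits" more concrete, the sketch below computes such a gradient by hand in plain Python and passes it to a stand-in callback. It assumes the loss is softmax cross-entropy, which this page does not state; under that assumption the gradient of one row is softmax(logits) minus the one-hot label, so each row sums to zero. The helper names and the toy logits are hypothetical and not part of OneFlow.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax over one row of logits."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy_grad(logits: List[float], label: int) -> List[float]:
    """Gradient of softmax cross-entropy w.r.t. logits: softmax - one_hot."""
    probs = softmax(logits)
    return [p - (1.0 if i == label else 0.0) for i, p in enumerate(probs)]


received: List[List[float]] = []


def watch_diff_handler(grad: List[float]) -> None:
    """Stand-in for the registered callback; records the gradient it is given."""
    received.append(grad)


# Simulate the backward pass handing the gradient to the callback.
watch_diff_handler(cross_entropy_grad([2.0, 1.0, 0.1], label=0))
print(received[0], sum(received[0]))

# The row printed in the sample output above, copied verbatim.
sample_row = [
    1.39966095e-03, 3.49164731e-03, 3.31605263e-02, 4.50417027e-03,
    7.73609674e-04, 4.89911772e-02, 2.47627571e-02, 7.65468649e-05,
    -1.18361652e-01, 1.20161276e-03,
]
print(sum(sample_row))
```

The printed sample row has a single negative entry and sums to roughly zero within float32 precision, which is consistent with the softmax cross-entropy assumption but does not prove it.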