- Code: `src/search/neural_networks`
- Base class: `AbstractNetwork`
`AbstractNetwork` defines the interface for all networks. Every network has to implement a function to evaluate a state, declare which outputs it provides (e.g. heuristic value and preferred operators), and provide access to these outputs.
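A minimal sketch of these three responsibilities (evaluate, declare, access) follows. It is not the real `AbstractNetwork` header: `State` is a hypothetical stand-in holding one integer per variable, the toy classes `CountingNetwork` and `SilentNetwork` and the helper `heuristic_or` are invented for illustration, and the exact signatures of `evaluate`, `is_heuristic` and `get_heuristic` are assumptions.

```cpp
#include <cassert>
#include <cstdlib>
#include <vector>

// Hypothetical stand-in for a planning state: one value per state variable.
struct State {
    std::vector<int> values;
};

// Simplified sketch of the AbstractNetwork interface (assumed, not the real header).
class AbstractNetwork {
public:
    virtual ~AbstractNetwork() = default;

    // Evaluation: every network must handle a single state and a batch of states.
    virtual void evaluate(const State &state) = 0;
    virtual void evaluate(const std::vector<State> &states) = 0;

    // Declaration: by default a network does not provide a heuristic value.
    virtual bool is_heuristic() const { return false; }

    // Access: networks that declare the output must override this getter.
    virtual int get_heuristic() const { std::abort(); }
};

// Toy network that "predicts" the number of variables with a non-zero value.
class CountingNetwork : public AbstractNetwork {
    std::vector<int> outputs;  // interpreted output of the last evaluate() call
public:
    void evaluate(const State &state) override {
        evaluate(std::vector<State>{state});  // a single state is a batch of size 1
    }
    void evaluate(const std::vector<State> &states) override {
        outputs.clear();
        for (const State &s : states) {
            int count = 0;
            for (int v : s.values) count += (v != 0);
            outputs.push_back(count);
        }
    }
    bool is_heuristic() const override { return true; }
    int get_heuristic() const override { return outputs.at(0); }
};

// Toy network that evaluates states but declares no heuristic output.
class SilentNetwork : public AbstractNetwork {
public:
    void evaluate(const State &) override {}
    void evaluate(const std::vector<State> &) override {}
};

// Caller side: query the declaration before reading the output.
int heuristic_or(AbstractNetwork &network, const State &state, int fallback) {
    if (!network.is_heuristic())
        return fallback;
    network.evaluate(state);
    return network.get_heuristic();
}
```

Checking `is_heuristic()` before calling the getter is what lets a caller deal with networks that do not provide a given output.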
If you want your network to output new information, e.g. the confidence in its heuristic estimate, you have to modify the `AbstractNetwork` class.
Add the following four methods (replace FEATURE with the name of your feature):
- `is_FEATURE()`: Denotes whether the concrete network class produces output for your feature. Its default implementation should return `false`.
- `verify_FEATURE()`: Terminates the execution if the network object does not support your feature.
- `RETURN_TYPE get_FEATURE()`: After calling `evaluate(State)`, the (interpreted) network output should be stored internally; this method gives access to it.
- `std::vector<RETURN_TYPE> get_FEATUREs()`: After calling `evaluate(vector<State>)`, the interpreted outputs for all states in the vector can be accessed with this method. If `evaluate(State)` was called, it should return a vector of size 1.
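As a sketch, assuming the new feature is a `double` confidence value, the four methods could look like this in a simplified stand-in for `AbstractNetwork`. The `State` struct, the toy `ToyConfidenceNetwork`, and the use of `std::exit` to terminate are assumptions for illustration, not taken from the real code.

```cpp
#include <cassert>
#include <cstdlib>
#include <iostream>
#include <vector>

struct State {
    std::vector<int> values;  // hypothetical stand-in for a planning state
};

// Simplified stand-in for AbstractNetwork, extended by a "confidence" feature.
class AbstractNetwork {
public:
    virtual ~AbstractNetwork() = default;
    virtual void evaluate(const State &state) = 0;
    virtual void evaluate(const std::vector<State> &states) = 0;

    // is_FEATURE(): does the concrete network produce this output? Default: no.
    virtual bool is_confidence() const { return false; }

    // verify_FEATURE(): terminate if the network does not support the feature.
    void verify_confidence() const {
        if (!is_confidence()) {
            std::cerr << "Network does not provide a confidence output." << std::endl;
            std::exit(EXIT_FAILURE);
        }
    }

    // get_FEATURE(): output for the state passed to evaluate(State).
    virtual double get_confidence() const {
        verify_confidence();  // terminates for networks without the feature
        return 0.0;  // only reached if is_confidence() is overridden but this getter is not
    }

    // get_FEATUREs(): outputs for all states of the last evaluate() call.
    virtual std::vector<double> get_confidences() const {
        verify_confidence();
        return {};
    }
};

// Toy network whose "confidence" is 1 / (1 + number of variables).
class ToyConfidenceNetwork : public AbstractNetwork {
    std::vector<double> confidences;  // interpreted outputs of the last evaluation
public:
    void evaluate(const State &state) override {
        evaluate(std::vector<State>{state});  // a single state is a batch of size 1
    }
    void evaluate(const std::vector<State> &states) override {
        confidences.clear();
        for (const State &s : states)
            confidences.push_back(1.0 / (1.0 + s.values.size()));
    }
    bool is_confidence() const override { return true; }
    double get_confidence() const override { return confidences.at(0); }
    std::vector<double> get_confidences() const override { return confidences; }
};
```

Implementing `evaluate(State)` as a batch of size 1 keeps `get_FEATUREs()` consistent with the requirement that it returns a vector of size 1 in that case.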
First, I suggest that you familiarize yourself with the C++ API of your framework: write a small example program that loads a model, feeds inputs into it, and extracts the output. Then compile and run it.
Once this works, add a new subclass of `AbstractNetwork`, e.g. `FRAMEWORKNetwork`. In its constructor and/or `initialize` method, perform all the preparation steps from your example program (e.g. loading the model). Do not forget to free any allocated memory in the destructor. Then implement the `evaluate` functions for your framework: extract the required input data from the given state(s), store it in the format your framework expects, feed it to the network, and interpret the output.
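The sketch below illustrates this structure. Because it has to compile without any ML library, the framework is replaced by a hypothetical C-style API (`FwModel`, `fw_load_model`, `fw_free_model`, `fw_run`) whose "inference" is a dot product; `State` and the minimal `AbstractNetwork` are stand-ins as well. What matters is the shape: load in the constructor, free in the destructor, and extract, feed, interpret in `evaluate`.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <utility>
#include <vector>

struct State {
    std::vector<int> values;  // hypothetical stand-in for a planning state
};

class AbstractNetwork {  // minimal stand-in for the real base class
public:
    virtual ~AbstractNetwork() = default;
    virtual void evaluate(const State &state) = 0;
    virtual void evaluate(const std::vector<State> &states) = 0;
};

// --- Mocked C-style framework API (hypothetical, replaces the real library) ---
struct FwModel {
    std::vector<float> weights;
};
FwModel *fw_load_model(std::vector<float> weights) { return new FwModel{std::move(weights)}; }
void fw_free_model(FwModel *model) { delete model; }
// "Inference": dot product of the input with the model weights.
float fw_run(const FwModel *model, const std::vector<float> &input) {
    assert(input.size() == model->weights.size());
    float result = 0.0f;
    for (std::size_t i = 0; i < input.size(); ++i)
        result += input[i] * model->weights[i];
    return result;
}

class FrameworkNetwork : public AbstractNetwork {
    FwModel *model;               // framework resource owned by this network
    std::vector<int> heuristics;  // interpreted outputs of the last evaluation
public:
    // Preparation from the example program: load the model once.
    explicit FrameworkNetwork(std::vector<float> weights)
        : model(fw_load_model(std::move(weights))) {}
    // Free the framework's memory in the destructor.
    ~FrameworkNetwork() override { fw_free_model(model); }
    // The class owns a raw handle: forbid copies to avoid a double free.
    FrameworkNetwork(const FrameworkNetwork &) = delete;
    FrameworkNetwork &operator=(const FrameworkNetwork &) = delete;

    void evaluate(const State &state) override {
        evaluate(std::vector<State>{state});
    }
    void evaluate(const std::vector<State> &states) override {
        heuristics.clear();
        for (const State &s : states) {
            // 1) Extract the input in the format the framework expects.
            std::vector<float> input(s.values.begin(), s.values.end());
            // 2) Feed it to the network.
            float raw = fw_run(model, input);
            // 3) Interpret the output as a non-negative integer heuristic.
            heuristics.push_back(std::max(0, static_cast<int>(std::lround(raw))));
        }
    }
    int get_heuristic() const { return heuristics.at(0); }
    const std::vector<int> &get_heuristics() const { return heuristics; }
};
```

If your framework offers RAII handles, or you wrap the raw handle in a `std::unique_ptr` with a custom deleter, the explicit destructor and the deleted copy operations become unnecessary.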
The most difficult part is usually adapting CMake so that your framework is found and linked during compilation. Take inspiration from the modules in `src/cmake_modules/Find*.cmake`.
For the network classes themselves, take a look at `ProtobufNetwork` and `TorchNetwork`.
I suggest implementing the input extraction and the output interpretation as abstract (pure virtual) functions. Then you can implement many different concrete networks for your framework on top of the same base class.
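One way to realize this is the template method pattern, sketched below under assumptions: `FrameworkNetworkBase` runs the (mocked) model, while the hypothetical pure virtual hooks `fill_input`, `extract_output` and `clear_output` are left to subclasses. These names are illustrative and may differ from the hooks actually declared in `ProtobufNetwork` or `TorchNetwork`; `State` is again a stand-in.

```cpp
#include <cassert>
#include <vector>

struct State {
    std::vector<int> values;  // hypothetical stand-in for a planning state
};

// Hypothetical framework base class: owns the model call and delegates
// input extraction and output interpretation to abstract functions.
class FrameworkNetworkBase {
public:
    virtual ~FrameworkNetworkBase() = default;

    void evaluate(const std::vector<State> &states) {
        clear_output();
        for (const State &state : states) {
            std::vector<float> input;
            fill_input(state, input);                      // subclass: encoding
            std::vector<float> output = run_model(input);
            extract_output(output);                        // subclass: decoding
        }
    }

protected:
    virtual void fill_input(const State &state, std::vector<float> &input) const = 0;
    virtual void extract_output(const std::vector<float> &output) = 0;
    virtual void clear_output() = 0;

private:
    // Stand-in for the framework's inference call: sums all inputs.
    std::vector<float> run_model(const std::vector<float> &input) const {
        float sum = 0.0f;
        for (float x : input) sum += x;
        return {sum};
    }
};

// Shared output interpretation: the first output value is an integer heuristic.
class HeuristicNetworkBase : public FrameworkNetworkBase {
public:
    const std::vector<int> &get_heuristics() const { return heuristics; }
protected:
    std::vector<int> heuristics;
    void extract_output(const std::vector<float> &output) override {
        heuristics.push_back(static_cast<int>(output.at(0)));
    }
    void clear_output() override { heuristics.clear(); }
};

// Concrete network 1: feeds the raw variable values.
class RawValueNetwork : public HeuristicNetworkBase {
protected:
    void fill_input(const State &state, std::vector<float> &input) const override {
        for (int v : state.values) input.push_back(static_cast<float>(v));
    }
};

// Concrete network 2: feeds 1 for every non-zero variable, 0 otherwise.
class NonZeroNetwork : public HeuristicNetworkBase {
protected:
    void fill_input(const State &state, std::vector<float> &input) const override {
        for (int v : state.values) input.push_back(v != 0 ? 1.0f : 0.0f);
    }
};
```

Both concrete networks reuse the same evaluation loop and output interpretation; they differ only in how a state is encoded.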
Decide which machine learning framework you want to use. Then extend the base class for that framework, e.g. `ProtobufNetwork` for TensorFlow or `TorchNetwork` for PyTorch, and implement the abstract methods of that base class. Typically, these:
- convert a state to the right input format for the network
- parse the output of the network
For every output the network produces, you have to override:
- `is_OUTPUT()`: return `true`.
- `RETURN_TYPE get_OUTPUT()`: return a single output of the right type.
- `std::vector<RETURN_TYPE> get_OUTPUTs()`: return a vector with one result per evaluated state.
For a simple example, take a look at the class `TestTorchNetwork`.
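For illustration, here is a hedged sketch of a concrete network that provides two outputs, a heuristic value and preferred operators, and therefore overrides one `is_`/`get_`/`get_...s` triple per output. Operators are represented by plain `int` indices, the raw output layout `[heuristic, score(op 0), score(op 1), ...]` and the 0.5 threshold are assumptions, and `parse_output` stands in for whatever output-parsing hook your framework base class declares. The minimal `AbstractNetwork` stand-in aborts for any output that is not overridden.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <vector>

// Minimal stand-in for AbstractNetwork: no outputs unless a subclass overrides them.
class AbstractNetwork {
public:
    virtual ~AbstractNetwork() = default;
    virtual bool is_heuristic() const { return false; }
    virtual bool is_preferred() const { return false; }
    virtual int get_heuristic() const { std::abort(); }
    virtual std::vector<int> get_heuristics() const { std::abort(); }
    virtual std::vector<int> get_preferred() const { std::abort(); }
    virtual std::vector<std::vector<int>> get_preferreds() const { std::abort(); }
};

// Hypothetical concrete network. Assumed raw output per state:
// [heuristic, score(op 0), score(op 1), ...]; operators scoring > 0.5 are preferred.
class HeuristicAndPreferredNetwork : public AbstractNetwork {
    std::vector<int> heuristics;               // one value per evaluated state
    std::vector<std::vector<int>> preferreds;  // preferred operator indices per state
public:
    // Parse the raw network output of a batch (called after running the model).
    void parse_output(const std::vector<std::vector<float>> &raw_batch) {
        heuristics.clear();
        preferreds.clear();
        for (const std::vector<float> &raw : raw_batch) {
            heuristics.push_back(static_cast<int>(raw.at(0)));
            std::vector<int> preferred;
            for (std::size_t i = 1; i < raw.size(); ++i)
                if (raw[i] > 0.5f)
                    preferred.push_back(static_cast<int>(i - 1));
            preferreds.push_back(preferred);
        }
    }

    // Output 1: heuristic value.
    bool is_heuristic() const override { return true; }
    int get_heuristic() const override { return heuristics.at(0); }
    std::vector<int> get_heuristics() const override { return heuristics; }

    // Output 2: preferred operators.
    bool is_preferred() const override { return true; }
    std::vector<int> get_preferred() const override { return preferreds.at(0); }
    std::vector<std::vector<int>> get_preferreds() const override { return preferreds; }
};
```

Storing one entry per state in the parsed batch means the single-state getters simply return the first element.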