-
-
Notifications
You must be signed in to change notification settings - Fork 196
394 convert misc nn classes to num power #395
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
apphp
wants to merge
6
commits into
RubixML:3.0
Choose a base branch
from
apphp:394-convert-misc-nn-classes-to-NumPower
base: 3.0
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
106c070
ML-394 Converted FeedForward to use NumPower
apphp e118309
ML-394 Converted Network to use NumPower
apphp 5dfc354
ML-394 Converted Snapshot to use NumPower
apphp 0e815da
ML-394 Improved FeedForwardTest
apphp a261d69
ML-394 Fixed copilot style comments, added prevention of division by zero
apphp 97c2173
Merge branch '3.0' into 394-convert-misc-nn-classes-to-NumPower
apphp File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,276 @@ | ||
| <?php | ||
|
|
||
| namespace Rubix\ML\NeuralNet\FeedForwards; | ||
|
|
||
| use NDArray; | ||
| use NumPower; | ||
| use Rubix\ML\NeuralNet\Layers\Base\Contracts\Hidden; | ||
| use Rubix\ML\NeuralNet\Layers\Base\Contracts\Input; | ||
| use Rubix\ML\NeuralNet\Layers\Base\Contracts\Layer; | ||
| use Rubix\ML\NeuralNet\Layers\Base\Contracts\Output; | ||
| use Rubix\ML\NeuralNet\Layers\Base\Contracts\Parametric; | ||
| use Rubix\ML\Encoding; | ||
| use Rubix\ML\Datasets\Dataset; | ||
| use Rubix\ML\Datasets\Labeled; | ||
| use Rubix\ML\NeuralNet\Networks\Network; | ||
| use Rubix\ML\NeuralNet\Optimizers\Base\Adaptive; | ||
| use Rubix\ML\NeuralNet\Optimizers\Base\Optimizer; | ||
| use Traversable; | ||
|
|
||
| use function array_reverse; | ||
|
|
||
| /** | ||
| * Feed Forward | ||
| * | ||
| * A feed forward neural network implementation consisting of an input and | ||
| * output layer and any number of intermediate hidden layers. | ||
| * | ||
| * @internal | ||
| * | ||
| * @category Machine Learning | ||
| * @package Rubix/ML | ||
| * @author Andrew DalPino | ||
| * @author Samuel Akopyan <leumas.a@gmail.com> | ||
| */ | ||
class FeedForward extends Network
{
    /**
     * The input layer to the network.
     *
     * @var Input
     */
    protected Input $input;

    /**
     * The hidden layers of the network, ordered left to right.
     *
     * @var list<Hidden>
     */
    protected array $hidden = [
        //
    ];

    /**
     * The hidden layers in reverse order, precomputed once so every
     * backward pass avoids re-reversing the layer list.
     *
     * @var list<Hidden>
     */
    protected array $backPass = [
        //
    ];

    /**
     * The output layer of the network.
     *
     * @var Output
     */
    protected Output $output;

    /**
     * The gradient descent optimizer used to train the network.
     *
     * @var Optimizer
     */
    protected Optimizer $optimizer;

    /**
     * @param Input $input
     * @param Hidden[] $hidden
     * @param Output $output
     * @param Optimizer $optimizer
     */
    public function __construct(Input $input, array $hidden, Output $output, Optimizer $optimizer)
    {
        // Reindex so $hidden is a proper list regardless of the keys passed in.
        $hidden = array_values($hidden);

        $backPass = array_reverse($hidden);

        $this->input = $input;
        $this->hidden = $hidden;
        $this->output = $output;
        $this->optimizer = $optimizer;
        $this->backPass = $backPass;
    }

    /**
     * Return the input layer.
     *
     * @return Input
     */
    public function input() : Input
    {
        return $this->input;
    }

    /**
     * Return an array of hidden layers indexed left to right.
     *
     * @return list<Hidden>
     */
    public function hidden() : array
    {
        return $this->hidden;
    }

    /**
     * Return the output layer.
     *
     * @return Output
     */
    public function output() : Output
    {
        return $this->output;
    }

    /**
     * Yield all the layers in the network in forward-pass order:
     * input, hidden (left to right), then output.
     *
     * @return Traversable<Layer>
     */
    public function layers() : Traversable
    {
        yield $this->input;

        yield from $this->hidden;

        yield $this->output;
    }

    /**
     * Return the number of trainable parameters in the network.
     *
     * Only layers implementing Parametric contribute; the count is the
     * sum of the element counts of each layer's parameter tensors.
     *
     * @return int
     */
    public function numParams() : int
    {
        $numParams = 0;

        foreach ($this->layers() as $layer) {
            if ($layer instanceof Parametric) {
                foreach ($layer->parameters() as $parameter) {
                    $numParams += $parameter->param()->size();
                }
            }
        }

        return $numParams;
    }

    /**
     * Initialize the parameters of the layers and warm the optimizer cache.
     *
     * Fan-in starts at 1 and is threaded through the layers so each layer
     * can size its parameters from the width of the previous one. Adaptive
     * optimizers are then warmed with every parametric layer's parameters.
     */
    public function initialize() : void
    {
        $fanIn = 1;

        foreach ($this->layers() as $layer) {
            $fanIn = $layer->initialize($fanIn);
        }

        if ($this->optimizer instanceof Adaptive) {
            foreach ($this->layers() as $layer) {
                if ($layer instanceof Parametric) {
                    foreach ($layer->parameters() as $param) {
                        $this->optimizer->warm($param);
                    }
                }
            }
        }
    }

    /**
     * Run an inference pass and return the activations at the output layer.
     *
     * Samples are transposed to column-major (features x samples) before
     * the pass and transposed back to row-major on return.
     *
     * @param Dataset $dataset
     * @return NDArray
     */
    public function infer(Dataset $dataset) : NDArray
    {
        $input = NumPower::transpose(NumPower::array($dataset->samples()), [1, 0]);

        foreach ($this->layers() as $layer) {
            $input = $layer->infer($input);
        }

        return NumPower::transpose($input, [1, 0]);
    }

    /**
     * Perform a forward and backward pass of the network in one call. Returns
     * the loss from the backward pass.
     *
     * @param Labeled $dataset
     * @return float
     */
    public function roundtrip(Labeled $dataset) : float
    {
        $input = NumPower::transpose(NumPower::array($dataset->samples()), [1, 0]);

        $this->feed($input);

        return $this->backpropagate($dataset->labels());
    }

    /**
     * Feed a batch through the network and return a matrix of activations
     * at the output layer.
     *
     * @param NDArray $input
     * @return NDArray
     */
    public function feed(NDArray $input) : NDArray
    {
        foreach ($this->layers() as $layer) {
            $input = $layer->forward($input);
        }

        return $input;
    }

    /**
     * Backpropagate the gradient of the cost function and return the loss.
     *
     * The output layer produces the initial gradient and the loss, then the
     * gradient is threaded backward through the hidden layers in reverse order.
     *
     * @param list<string|int|float> $labels
     * @return float
     */
    public function backpropagate(array $labels) : float
    {
        [$gradient, $loss] = $this->output->back($labels, $this->optimizer);

        foreach ($this->backPass as $layer) {
            $gradient = $layer->back($gradient, $this->optimizer);
        }

        return $loss;
    }

    /**
     * Export the network architecture as a graph in dot format.
     *
     * Each layer becomes a node N1..Nk connected in a simple chain.
     *
     * @return Encoding
     */
    public function exportGraphviz() : Encoding
    {
        $dot = 'digraph Tree {' . PHP_EOL;
        $dot .= '  node [shape=box, fontname=helvetica];' . PHP_EOL;

        $layerNum = 0;

        foreach ($this->layers() as $layer) {
            ++$layerNum;

            $dot .= "  N$layerNum [label=\"$layer\",style=\"rounded\"]" . PHP_EOL;

            if ($layerNum > 1) {
                $parentId = $layerNum - 1;

                $dot .= "  N{$parentId} -> N{$layerNum};" . PHP_EOL;
            }
        }

        $dot .= '}';

        return new Encoding($dot);
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,7 +6,7 @@ | |
| use Stringable; | ||
|
|
||
| /** | ||
| * Hidden | ||
| * Layer | ||
| * | ||
| * @category Machine Learning | ||
| * @package Rubix/ML | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.