THE AUDITORY MODELING TOOLBOX

This documentation page applies to an outdated AMT version (1.5.0).

LLADO2022_TRAINNN - trains the neural network

Program code:

function [net,tr] = llado2022_trainnn(x,y,hiddenLayerSize,augmentation_ratio,SNR)
%LLADO2022_TRAINNN trains the neural network
%   Usage: [net,tr] = llado2022_trainnn(x,y,hiddenLayerSize,augmentation_ratio,SNR);
%
%   Input parameters:
%     x                  : Features of the train subset
%     y                  : Labels for training the network
%     hiddenLayerSize    : Size of the hidden layer
%     augmentation_ratio : Ratio for data augmentation stage
%     SNR                : SNR of the augmented data
%
%   Output parameters:
%     net                : trained network
%     tr                 : training history
%
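%   LLADO2022_TRAINNN(x,y,hiddenLayerSize,augmentation_ratio,SNR) creates a
%   function-fitting network (fitnet) with the given hidden layer size and
%   trains it on x and y, where rows are observations. Before training,
%   the data are augmented: augmentation_ratio noisy copies of x, each
%   with additive Gaussian noise at the given SNR (in dB), are appended
%   together with the unchanged labels y. The augmented set is split into
%   70% training and 30% validation data.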
%
%   See also: llado2022 exp_llado2022 demo_llado2022
%
%   References:
%     Lladó, Pedro, Hyvärinen, Petteri, and Pulkki, Ville. Auditory
%     model-based estimation of the effect of head-worn devices on frontal
%     horizontal localisation. Acta Acust., 6:1, 2022.
%     https://doi.org/10.1051/aacus/2021056
%
%   Url: http://amtoolbox.org/amt-1.5.0/doc/modelstages/llado2022_trainnn.php


%   #StatusDoc: Good
%   #StatusCode: Perfect
%   #Verification: Verified
%   #Requirements: MATLAB M - Communication Systems
%   #Author: Pedro Llado (2022)
%   #Author: Petteri Hyvärinen (2022)
%   #Author: Ville Pulkki (2022)

% This file is licensed under the GNU General Public License (GPL) either 
% version 3 of the license, or any later version as published by the Free Software 
% Foundation. Details of the GPLv3 can be found in the AMT directory "licences" and 
% at <https://www.gnu.org/licenses/gpl-3.0.html>. 
% You can redistribute this file and/or modify it under the terms of the GPLv3. 
% This file is distributed without any warranty; without even the implied warranty 
% of merchantability or fitness for a particular purpose. 
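% Create a function-fitting network with hidden layer size hiddenLayerSize
% (fitnet and train belong to MATLAB's Deep Learning Toolbox).
net = fitnet(hiddenLayerSize);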

% Use 70% of the samples for training and 30% for validation; no test subset.
net.divideParam.trainRatio = 70/100;
net.divideParam.valRatio = 30/100;
net.divideParam.testRatio = 0/100;

%% Training data augmentation

% Labels: stack augmentation_ratio additional copies of y below the original,
% so that each noisy copy of the features keeps its original label.
Y_output_aug = y;
for aug_iter = 1:augmentation_ratio
    idx = aug_iter*size(y,1)+1 : (aug_iter+1)*size(y,1);
    Y_output_aug(idx,:) = y;
end

% Features: append augmentation_ratio noisy copies of x. For each feature
% column, white Gaussian noise is scaled to a level SNR dB below the level
% of that column (as given by dbspl) and added to the column.
X_input_aug = x;
for aug_iter = 1:augmentation_ratio
    idx = aug_iter*size(x,1)+1 : (aug_iter+1)*size(x,1);
    for id_col = 1:size(x,2)
        %aux= awgn(x(:,id_col),SNR,'measured');
        auxnoise = randn(size(x,1),1);
        X_input_aug(idx,id_col) = x(:,id_col) + ...
            scaletodbspl(auxnoise, dbspl(x(:,id_col)) - SNR);
    end
end

% train expects one sample per column, hence the transposes.
[net, tr] = train(net,X_input_aug',Y_output_aug');

end
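
% Usage sketch (illustrative, not part of the original file). It assumes
% that the AMT has been started with amt_start and that the toolbox
% providing fitnet is installed. The data and the parameter values
% (10 hidden units, augmentation ratio 2, SNR of 30 dB) are hypothetical.
%
%   x = randn(200,10);                 % 200 observations, 10 features
%   y = sum(x,2);                      % one synthetic label per observation
%   [net,tr] = llado2022_trainnn(x,y,10,2,30);
%   y_est = net(x')';                  % the network expects one sample per column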