916fe4f9 |
%% subject loop
function decode = som_subject_batch(header,subjectdata,somOpts)
addpath 'somtoolbox2';
RANDOMIZE_DATAPOINTS = 1;
decode = struct;
decode.decodePerformance = [];
decode.rawTimeCourse = [];
nSubjects = numel(subjectdata);
disp(sprintf('batch processing %u subjects',nSubjects));
timeline = header.timeline;
timeline.frameShiftStart = header.frameShift.frameShiftStart;
timeline.frameShiftEnd = header.frameShift.frameShiftEnd;
timeline.decodeDuration = header.frameShift.decodeDuration;
timeLineStart = timeline.frameShiftStart;
timeLineEnd = timeline.frameShiftEnd;
% TimePointMatrix
for subjectDataID = 1:nSubjects
currentSubject = subjectdata{subjectDataID};
timePointArgs.pst = currentSubject.pst;
timePointArgs.labelMap = LabelMap(header.classDef.labelCells,header.classDef.conditionCells);
timePointArgs.eventList = header.classDef.eventMatrix;
timePointMatrix{subjectDataID} = buildTimePointMatrix(timeline,timePointArgs);
decode.rawTimeCourse = [decode.rawTimeCourse currentSubject.pst];
display(sprintf('%u -fold cross validation for %u timeslices.\n',somOpts.nfold,size(1:timeLineEnd-timeLineStart+1,2)));
|
916fe4f9 |
for timeIndex = 1:timeLineEnd-timeLineStart+1
svmstruct = calculateSVMTables(timePointMatrix{subjectDataID},timeIndex);
nElements = length(svmstruct.svmlabel);
if RANDOMIZE_DATAPOINTS
rndindex = randperm(nElements);
svmstruct.svmdata = svmstruct.svmdata(rndindex,:);
svmstruct.svmlabel = svmstruct.svmlabel(rndindex);
end
chunksize = nElements / somOpts.nfold;
cross_value = 0;
for iFold = 1:somOpts.nfold
chunkstart = (iFold-1)*chunksize+1;
chunkend = min(iFold*chunksize,nElements);
svm_train_label = svmstruct.svmlabel;
svm_train_data = svmstruct.svmdata;
svm_validation_label = svmstruct.svmlabel(chunkstart:chunkend);
svm_validation_data = svmstruct.svmdata(chunkstart:chunkend,:);
svm_train_label(chunkstart:chunkend) = []; %del test set
svm_train_data(chunkstart:chunkend,:) = [];% del test set
[sD sM] = som_train(svm_train_label, svm_train_data, somOpts);
performance = som_decode(sM, svm_validation_data,svm_validation_label);
cross_value = [cross_value, performance];
end
|