private/som_subject_batch.m
916fe4f9
 %% subject loop
 function decode = som_subject_batch(header,subjectdata,somOpts)
 %SOM_SUBJECT_BATCH Per-subject, per-time-slice SOM decoding with k-fold cross-validation.
 %
 %   decode = som_subject_batch(header, subjectdata, somOpts)
 %
 %   For every subject in SUBJECTDATA a time-point matrix is built with
 %   buildTimePointMatrix. For every time slice in the inclusive range
 %   header.frameShift.frameShiftStart .. header.frameShift.frameShiftEnd,
 %   the table returned by calculateSVMTables is scored with
 %   somOpts.nfold-fold cross-validation (som_train on the training folds,
 %   som_decode on the held-out fold).
 %
 %   Inputs
 %     header      - struct; fields read: timeline,
 %                   frameShift.frameShiftStart, frameShift.frameShiftEnd,
 %                   frameShift.decodeDuration, classDef.labelCells,
 %                   classDef.conditionCells, classDef.eventMatrix.
 %     subjectdata - cell array of subject structs; each must have .pst.
 %     somOpts     - options struct passed through to som_train;
 %                   somOpts.nfold is the number of cross-validation folds.
 %
 %   Output
 %     decode.decodePerformance - one column per subject, one row per time
 %                                slice; each entry is the mean of the
 %                                som_decode results over the non-empty
 %                                folds (NaN if every fold is empty).
 %                                Assumes som_decode returns a scalar or a
 %                                row vector - TODO confirm.
 %     decode.rawTimeCourse     - subjects' .pst values concatenated
 %                                horizontally: [pst_1 pst_2 ...].
 %
 %   Side effect: adds the relative folder 'somtoolbox2' to the MATLAB path.
 
 addpath 'somtoolbox2';
 
 % Shuffle samples before splitting into folds so folds are not ordered by
 % whatever ordering calculateSVMTables produces.
 RANDOMIZE_DATAPOINTS = 1;
 
 decode = struct;
 decode.decodePerformance = [];
 decode.rawTimeCourse     = [];
 
 nSubjects = numel(subjectdata);
 fprintf('batch processing %u subjects\n', nSubjects);
 
 timeline = header.timeline;
 timeline.frameShiftStart = header.frameShift.frameShiftStart;
 timeline.frameShiftEnd   = header.frameShift.frameShiftEnd;
 timeline.decodeDuration  = header.frameShift.decodeDuration;
 
 % Number of time slices in the inclusive frame-shift range; an inverted
 % range yields zero slices.
 nTimeSlices = max(0, timeline.frameShiftEnd - timeline.frameShiftStart + 1);
 
 for subjectDataID = 1:nSubjects
     currentSubject = subjectdata{subjectDataID};
 
     % LabelMap is rebuilt for every subject on purpose: it may be a handle
     % object that buildTimePointMatrix mutates - NOTE(review): confirm
     % before hoisting it out of the loop.
     timePointArgs = struct();
     timePointArgs.pst       = currentSubject.pst;
     timePointArgs.labelMap  = LabelMap(header.classDef.labelCells,header.classDef.conditionCells);
     timePointArgs.eventList = header.classDef.eventMatrix;
 
     subjectMatrix = buildTimePointMatrix(timeline,timePointArgs);
 
     decode.rawTimeCourse = [decode.rawTimeCourse currentSubject.pst];
 
     fprintf('%u -fold cross validation for %u timeslices.\n', somOpts.nfold, nTimeSlices);
 
     decode_timeline = zeros(nTimeSlices, 1);
     for timeIndex = 1:nTimeSlices
         svmstruct = calculateSVMTables(subjectMatrix,timeIndex);
         decode_timeline(timeIndex) = crossValidateSlice(svmstruct, somOpts, RANDOMIZE_DATAPOINTS);
     end
 
     decode.decodePerformance = [decode.decodePerformance decode_timeline];
 end
 fprintf('decode done\n');
 end
 
 
 function meanScore = crossValidateSlice(svmstruct, somOpts, randomize)
 %CROSSVALIDATESLICE Mean som_decode score of one time slice under k-fold CV.
 %   Samples are split into somOpts.nfold contiguous folds whose sizes differ
 %   by at most one (identical to equal-sized chunks when the sample count is
 %   divisible by nfold). Each fold is held out once: a SOM is trained with
 %   som_train on the remaining samples and scored with som_decode on the
 %   held-out fold. Folds that are empty (fewer samples than folds) are
 %   skipped. Returns NaN when no fold could be evaluated.
     labels    = svmstruct.svmlabel;
     data      = svmstruct.svmdata;
     nElements = length(labels);
 
     if randomize
         order  = randperm(nElements);
         data   = data(order,:);
         labels = labels(order);
     end
 
     % Integer fold boundaries: fold k covers samples edges(k)+1 .. edges(k+1).
     % Rounding avoids the fractional subscripts a plain nElements/nfold
     % chunk size produces when the division is not exact.
     edges = round(linspace(0, nElements, somOpts.nfold + 1));
 
     % Collect only real fold scores; seeding with 0 would bias the mean.
     foldScores = [];
     for iFold = 1:somOpts.nfold
         isTest = false(nElements, 1);
         isTest(edges(iFold)+1 : edges(iFold+1)) = true;
         if ~any(isTest)
             continue;   % empty fold: nothing to validate on
         end
 
         [~, sM] = som_train(labels(~isTest), data(~isTest,:), somOpts);
         performance = som_decode(sM, data(isTest,:), labels(isTest));
 
         foldScores = [foldScores, performance]; %#ok<AGROW>
     end
     meanScore = mean(foldScores);
 end