%% subject loop
function decode = som_subject_batch(header,subjectdata,somOpts)
% SOM_SUBJECT_BATCH  Per-subject, per-time-slice n-fold cross-validated SOM decoding.
%
%   decode = som_subject_batch(header, subjectdata, somOpts)
%
%   For every subject, builds a time-point matrix (buildTimePointMatrix).
%   For every time slice, it then extracts a feature/label table
%   (calculateSVMTables) and runs somOpts.nfold-fold cross-validation:
%   som_train on the training folds, som_decode on the held-out fold.
%
%   Inputs
%     header      - struct. Fields read here: timeline,
%                   frameShift.frameShiftStart, frameShift.frameShiftEnd,
%                   frameShift.decodeDuration, classDef.labelCells,
%                   classDef.conditionCells, classDef.eventMatrix.
%     subjectdata - cell array with one entry per subject. Each entry's .pst
%                   field is used.
%     somOpts     - options struct forwarded to som_train. Field .nfold is
%                   the number of cross-validation folds.
%
%   Output (struct)
%     decode.decodePerformance - one row per (subject, time slice), appended
%                                subject-major. Each row is
%                                [0, perf_fold1, ..., perf_foldN].
%                                NOTE(review): the leading 0 comes from the
%                                accumulator's initial value. It is kept so
%                                the row width is unchanged for existing
%                                callers; exclude it when averaging.
%                                (Assumes som_decode returns a scalar.)
%     decode.rawTimeCourse     - horizontal concatenation of every subject's
%                                .pst.

addpath 'somtoolbox2';

% When true, observations are shuffled before they are split into folds.
RANDOMIZE_DATAPOINTS = 1;

decode = struct;
decode.decodePerformance = [];
decode.rawTimeCourse = [];

nSubjects = numel(subjectdata);
disp(sprintf('batch processing %u subjects',nSubjects));

% Extend the header timeline with the frame-shift window, which
% buildTimePointMatrix receives.
timeline = header.timeline;
timeline.frameShiftStart = header.frameShift.frameShiftStart;
timeline.frameShiftEnd = header.frameShift.frameShiftEnd;
timeline.decodeDuration = header.frameShift.decodeDuration;

% One decoding run per time slice inside the frame-shift window.
timeIndices = 1:(timeline.frameShiftEnd - timeline.frameShiftStart + 1);

% TimePointMatrix
timePointMatrix = cell(1, nSubjects);
for subjectDataID = 1:nSubjects
    currentSubject = subjectdata{subjectDataID};

    timePointArgs.pst = currentSubject.pst;
    timePointArgs.labelMap = LabelMap(header.classDef.labelCells,header.classDef.conditionCells);
    timePointArgs.eventList = header.classDef.eventMatrix;
    timePointMatrix{subjectDataID} = buildTimePointMatrix(timeline,timePointArgs);

    decode.rawTimeCourse = [decode.rawTimeCourse currentSubject.pst];

    display(sprintf('%u -fold cross validation for %u timeslices.\n',somOpts.nfold,numel(timeIndices)));
    for timeIndex = timeIndices
        svmstruct = calculateSVMTables(timePointMatrix{subjectDataID},timeIndex);
        cross_value = som_subject_batch_crossvalidate(svmstruct, somOpts, RANDOMIZE_DATAPOINTS);
        decode.decodePerformance = [decode.decodePerformance; cross_value];
    end
end

display('decode done');
end


function cross_value = som_subject_batch_crossvalidate(svmstruct, somOpts, randomize)
% Run somOpts.nfold-fold cross-validation of som_train/som_decode on one
% time slice's feature table.
%
%   svmstruct - struct with .svmdata (one observation per row) and
%               .svmlabel (one label per observation), as returned by
%               calculateSVMTables.
%   somOpts   - options forwarded to som_train. Field .nfold is the number
%               of folds.
%   randomize - when true, observations are shuffled (randperm) before fold
%               assignment.
%
%   Returns the row vector [0, perf_1, ..., perf_nfold]. The leading 0 is
%   kept for compatibility with existing consumers of
%   decode.decodePerformance.

nElements = length(svmstruct.svmlabel);

if randomize
    order = randperm(nElements);
    svmstruct.svmdata = svmstruct.svmdata(order,:);
    svmstruct.svmlabel = svmstruct.svmlabel(order);
end

% Integer fold boundaries. A chunk size of nElements/nfold produces
% fractional indices, and MATLAB raises an indexing error, whenever nfold
% does not divide nElements. Rounded linspace edges give contiguous,
% non-overlapping folds that cover every observation exactly once. They
% match the even split exactly when nfold divides nElements.
foldEdges = round(linspace(0, nElements, somOpts.nfold + 1));

cross_value = 0;
for iFold = 1:somOpts.nfold
    testIdx = (foldEdges(iFold) + 1):foldEdges(iFold + 1);

    % Held-out fold.
    svm_validation_label = svmstruct.svmlabel(testIdx);
    svm_validation_data = svmstruct.svmdata(testIdx,:);

    % Training set = everything except the held-out fold.
    svm_train_label = svmstruct.svmlabel;
    svm_train_data = svmstruct.svmdata;
    svm_train_label(testIdx) = [];
    svm_train_data(testIdx,:) = [];

    [~, sM] = som_train(svm_train_label, svm_train_data, somOpts);
    performance = som_decode(sM, svm_validation_data, svm_validation_label);
    cross_value = [cross_value, performance]; %#ok<AGROW>
end
end