summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mv_machine_learning/face_recognition/include/face_recognition.h4
-rw-r--r--mv_machine_learning/face_recognition/include/nntrainer_dsm.h4
-rw-r--r--mv_machine_learning/face_recognition/include/nntrainer_fvm.h5
-rw-r--r--mv_machine_learning/face_recognition/src/face_recognition.cpp156
-rw-r--r--mv_machine_learning/face_recognition/src/nntrainer_dsm.cpp50
-rw-r--r--mv_machine_learning/face_recognition/src/nntrainer_fvm.cpp15
-rw-r--r--mv_machine_learning/training/include/data_set_manager.h9
-rw-r--r--mv_machine_learning/training/include/feature_vector_manager.h7
-rw-r--r--mv_machine_learning/training/include/training_model.h1
-rw-r--r--mv_machine_learning/training/src/data_set_manager.cpp6
-rw-r--r--mv_machine_learning/training/src/training_model.cpp10
11 files changed, 140 insertions, 127 deletions
diff --git a/mv_machine_learning/face_recognition/include/face_recognition.h b/mv_machine_learning/face_recognition/include/face_recognition.h
index 377e359b..0a4620ac 100644
--- a/mv_machine_learning/face_recognition/include/face_recognition.h
+++ b/mv_machine_learning/face_recognition/include/face_recognition.h
@@ -102,9 +102,7 @@ private:
std::unique_ptr<DataSetManager> CreateDSM(const mv_inference_backend_type_e backend_type);
std::unique_ptr<FeatureVectorManager> CreateFVM(const mv_inference_backend_type_e backend_type,
std::string file_name);
- void UpdateDataSet(std::unique_ptr<DataSetManager> &data_set, std::vector<float> &feature_vec, const int label_idx,
- const int label_cnt);
- void UpdateDataSet(std::unique_ptr<DataSetManager> &data_set);
+ void StoreDataSet(std::unique_ptr<DataSetManager> &data_set, unsigned int label_cnt);
int GetAnswer();
std::vector<model_layer_info> &GetBackboneInputLayerInfo();
int GetVecFromMvSource(mv_source_h img_src, std::vector<float> &out_vec);
diff --git a/mv_machine_learning/face_recognition/include/nntrainer_dsm.h b/mv_machine_learning/face_recognition/include/nntrainer_dsm.h
index a2ece919..20eac475 100644
--- a/mv_machine_learning/face_recognition/include/nntrainer_dsm.h
+++ b/mv_machine_learning/face_recognition/include/nntrainer_dsm.h
@@ -31,7 +31,9 @@ public:
NNTrainerDSM();
~NNTrainerDSM() = default;
- void LoadDataSet(const std::string file_name) override;
+ void LoadDataSet(const std::string file_name, unsigned int new_label_cnt) override;
+ void AddDataSet(std::vector<float> &feature_vec, const unsigned int label_idx,
+ const unsigned int label_cnt) override;
};
#endif \ No newline at end of file
diff --git a/mv_machine_learning/face_recognition/include/nntrainer_fvm.h b/mv_machine_learning/face_recognition/include/nntrainer_fvm.h
index c410fa39..2be41ed6 100644
--- a/mv_machine_learning/face_recognition/include/nntrainer_fvm.h
+++ b/mv_machine_learning/face_recognition/include/nntrainer_fvm.h
@@ -29,9 +29,10 @@ public:
NNTrainerFVM(const std::string feature_vector_file = "feature_vector_file.dat");
~NNTrainerFVM() = default;
- void WriteHeader(size_t feature_size, size_t one_hot_table_size, unsigned int data_set_cnt) override;
+ void WriteHeader(size_t feature_size, size_t label_cnt, unsigned int data_set_cnt) override;
void ReadHeader(FeaVecHeader &header) override;
- void WriteFeatureVec(std::vector<float> &feature_vec, const int max_label, const int label_index) override;
+ void StoreData(std::vector<std::vector<float> > &features_vec,
+ std::vector<unsigned int> &label_index) override;
void Remove() override;
};
diff --git a/mv_machine_learning/face_recognition/src/face_recognition.cpp b/mv_machine_learning/face_recognition/src/face_recognition.cpp
index 440408e4..8447a095 100644
--- a/mv_machine_learning/face_recognition/src/face_recognition.cpp
+++ b/mv_machine_learning/face_recognition/src/face_recognition.cpp
@@ -112,11 +112,8 @@ unique_ptr<FeatureVectorManager> FaceRecognition::CreateFVM(const mv_inference_b
throw InvalidParameter("Invalid training engine backend type.");
}
-void FaceRecognition::UpdateDataSet(unique_ptr<DataSetManager> &data_set, vector<float> &feature_vec,
- const int label_idx, const int label_cnt)
+void FaceRecognition::StoreDataSet(unique_ptr<DataSetManager> &data_set, unsigned int label_cnt)
{
- size_t data_set_cnt = 0;
-
try {
auto fvm = CreateFVM(_config.training_engine_backend_type, _config.feature_vector_file_path);
auto fvm_new = CreateFVM(_config.training_engine_backend_type, _config.feature_vector_file_path + ".new");
@@ -124,40 +121,11 @@ void FaceRecognition::UpdateDataSet(unique_ptr<DataSetManager> &data_set, vector
// Make sure feature vector file.
CheckFeatureVectorFile(fvm, fvm_new);
- data_set = CreateDSM(_config.training_engine_backend_type);
-
- // 1. If data set file exists then load the file to DataSetManager object first
- // and then write them to the data set file with updated label value, and then
- // write a new dataset to the data set file.
- // Otherwise, it writes only new data set to the data set file.
- if (FaceRecogUtil::IsFileExist(fvm->GetFileName())) {
- data_set->LoadDataSet(fvm->GetFileName());
-
- vector<vector<float> > feature_vectors = data_set->GetData();
- vector<unsigned int> label_idx_vectors = data_set->GetLabelIdx();
-
- // 1) Write existing feature vectors and its one-hot encoding table considered
- // for new label count to the data set file.
- for (unsigned int idx = 0; idx < feature_vectors.size(); ++idx)
- fvm_new->WriteFeatureVec(feature_vectors[idx], label_cnt, label_idx_vectors[idx]);
-
- data_set_cnt += feature_vectors.size();
-
- // 2) If same feature vector isn't duplicated then write the feature vector to data set file.
- if (!data_set->IsFeatureVectorDuplicated(feature_vec)) {
- fvm_new->WriteFeatureVec(feature_vec, label_cnt, label_idx);
- LOGD("Added a new feature vector to data set file.");
- data_set_cnt++;
- }
- } else {
- // 1) Write only a new data set to the data st file.
- fvm_new->WriteFeatureVec(feature_vec, label_cnt, label_idx);
- LOGD("Added a new feature vector to data set file.");
- data_set_cnt++;
- }
+ // 1. Write feature vector and it's label index.
+ fvm_new->StoreData(data_set->GetData(), data_set->GetLabelIdx());
// 2. Write feature vector header.
- fvm_new->WriteHeader(feature_vec.size(), label_cnt, data_set_cnt);
+ fvm_new->WriteHeader(data_set->GetFeaVecSize(), label_cnt, data_set->GetData().size());
int ret = 0;
@@ -171,31 +139,12 @@ void FaceRecognition::UpdateDataSet(unique_ptr<DataSetManager> &data_set, vector
ret = ::rename(fvm_new->GetFileName().c_str(), fvm->GetFileName().c_str());
if (ret)
throw InvalidOperation("Fail to rename new feature vector file to original one.");
-
- data_set->Clear();
- data_set->LoadDataSet(fvm->GetFileName());
} catch (const BaseException &e) {
LOGE("%s", e.what());
throw e;
}
}
-void FaceRecognition::UpdateDataSet(unique_ptr<DataSetManager> &data_set)
-{
- try {
- data_set = CreateDSM(_config.training_engine_backend_type);
-
- auto fvm = CreateFVM(_config.training_engine_backend_type, _config.feature_vector_file_path);
-
- if (FaceRecogUtil::IsFileExist(fvm->GetFileName()) == false)
- throw InvalidOperation("Feature vector file not found.");
-
- data_set->LoadDataSet(fvm->GetFileName());
- } catch (const BaseException &e) {
- LOGE("%s", e.what());
- throw e;
- }
-}
void FaceRecognition::SetConfig(FaceRecognitionConfig &config)
{
@@ -360,18 +309,35 @@ int FaceRecognition::RegisterNewFace(mv_source_h img_src, string label_name)
copy(buffer, buffer + backbone_output_buffer->size / sizeof(float), back_inserter(feature_vec));
// Get label index and count.
- int label_idx = _label_manager->GetLabelIndex(label_name);
- int label_cnt = _label_manager->GetMaxLabel();
+ unsigned int label_idx = _label_manager->GetLabelIndex(label_name);
+ unsigned int label_cnt = _label_manager->GetMaxLabel();
_training_model->ConfigureModel(label_cnt);
- unique_ptr<DataSetManager> data_set;
+ unique_ptr<DataSetManager> data_set = CreateDSM(_config.training_engine_backend_type);
+
+ data_set->Clear();
+
+ // Load existing feature vectors if the feature vector file exists.
+ if (FaceRecogUtil::IsFileExist(_config.feature_vector_file_path) == true) {
+ LOGI("feature vector file already exists so it loads the file first.");
+ data_set->LoadDataSet(_config.feature_vector_file_path, label_cnt);
+ }
+
+ // Add new feature vectors.
+ data_set->AddDataSet(feature_vec, label_idx, label_cnt);
- UpdateDataSet(data_set, feature_vec, label_idx, label_cnt);
_training_model->ApplyDataSet(data_set);
_training_model->Compile();
_training_model->Train();
+ // TODO. apply feature vector priority policy here.
+ // We can get weight trained from NNTrainer.
+ // _training_model->getWeights(&weights, &size, "centroid_knn1");
+
+ // Store dataset to feature vector file.
+ StoreDataSet(data_set, label_cnt);
+
// label_cnt can be changed every time the training is performed and all data set will be used for the training
// again in this case. So make sure to clear previous data set before next training.
_training_model->ClearDataSet(data_set);
@@ -524,7 +490,6 @@ int FaceRecognition::RecognizeFace(mv_source_h img_src)
_result.raw_data.clear();
copy(raw_buffer, raw_buffer + internal_output_buffer->size / sizeof(float), back_inserter(_result.raw_data));
-
_status = INFERENCED;
return GetAnswer();
@@ -556,6 +521,8 @@ int FaceRecognition::DeleteLabel(string label_name)
unsigned int target_label_idx = _label_manager->GetLabelIndex(label_name);
+ auto label_cnt_ori = _label_manager->GetMaxLabel();
+
// Get label count after removing a given label from the label file.
_label_manager->RemoveLabel(label_name);
@@ -568,17 +535,21 @@ int FaceRecognition::DeleteLabel(string label_name)
auto data_set = CreateDSM(_config.training_engine_backend_type);
- data_set->LoadDataSet(fvm->GetFileName());
+ // feature vectors corresponding to given label aren't removed yet from feature vector file.
+ // So label_cnt_ori is needed.
+ data_set->LoadDataSet(fvm->GetFileName(), label_cnt_ori);
- vector<vector<float> > feature_vectors = data_set->GetData();
- vector<unsigned int> label_idx_vectors = data_set->GetLabelIdx();
+ vector<vector<float> > feature_vectors_old = data_set->GetData();
+ vector<unsigned int> label_idx_vectors_old = data_set->GetLabelIdx();
+ vector<vector<float> > feature_vectors_new;
+ vector<unsigned int> label_idx_vectors_new;
size_t data_set_cnt = 0;
// Write existing feature vectors and its one-hot encoding table with updated label.
- for (unsigned int idx = 0; idx < feature_vectors.size(); ++idx) {
+ for (unsigned int idx = 0; idx < feature_vectors_old.size(); ++idx) {
// Except the data sets with a given target_label_idx.
- if (label_idx_vectors[idx] == target_label_idx)
+ if (label_idx_vectors_old[idx] == target_label_idx)
continue;
// One-hot encoding table should be updated.
@@ -596,43 +567,52 @@ int FaceRecognition::DeleteLabel(string label_name)
// offset 1 : label 3
//
// So if the index of removed label less than remaining index then decrease each index.
- if (label_idx_vectors[idx] > target_label_idx)
- label_idx_vectors[idx]--;
+ if (label_idx_vectors_old[idx] > target_label_idx)
+ label_idx_vectors_old[idx]--;
- fvm_new->WriteFeatureVec(feature_vectors[idx], label_cnt, label_idx_vectors[idx]);
+ feature_vectors_new.push_back(feature_vectors_old[idx]);
+ label_idx_vectors_new.push_back(label_idx_vectors_old[idx]);
data_set_cnt++;
}
- fvm_new->WriteHeader(feature_vectors[0].size(), label_cnt, data_set_cnt);
+ // Retrain only in case that feature vectors exist.
+ if (data_set_cnt > 0) {
+ fvm_new->StoreData(feature_vectors_new, label_idx_vectors_new);
+ fvm_new->WriteHeader(feature_vectors_new[0].size(), label_cnt, data_set_cnt);
- int ret = 0;
-
- if (FaceRecogUtil::IsFileExist(fvm->GetFileName())) {
- // Change new data file to existing one.
- ret = ::remove(fvm->GetFileName().c_str());
- if (ret)
- throw InvalidOperation("Fail to remove feature vector file.");
- }
+ int ret = 0;
- ret = ::rename(fvm_new->GetFileName().c_str(), fvm->GetFileName().c_str());
- if (ret)
- throw InvalidOperation("Fail to rename new feature vector file to original one.");
+ if (FaceRecogUtil::IsFileExist(fvm->GetFileName())) {
+ // Change new data file to existing one.
+ ret = ::remove(fvm->GetFileName().c_str());
+ if (ret)
+ throw InvalidOperation("Fail to remove feature vector file.");
+ }
- if (data_set_cnt == 0) {
- _training_model->RemoveModel();
- fvm->Remove();
- _label_manager->Remove();
+ ret = ::rename(fvm_new->GetFileName().c_str(), fvm->GetFileName().c_str());
+ if (ret)
+ throw InvalidOperation("Fail to rename new feature vector file to original one.");
- LOGD("No training data so removed all relevant files.");
- } else {
_training_model->ConfigureModel(label_cnt);
+ unique_ptr<DataSetManager> new_data_set = CreateDSM(_config.training_engine_backend_type);
+ new_data_set->Clear();
- unique_ptr<DataSetManager> new_data_set;
+ // TODO. Remove existing internal model file.
- UpdateDataSet(new_data_set);
+ new_data_set->LoadDataSet(_config.feature_vector_file_path, label_cnt);
_training_model->ApplyDataSet(new_data_set);
_training_model->Compile();
_training_model->Train();
+
+ // TODO. apply feature vector priority policy here.
+ // We can get weight trained from NNTrainer.
+ // _training_model->getWeights(&weights, &size, "centroid_knn1");
+ } else {
+ _training_model->RemoveModel();
+ fvm->Remove();
+ _label_manager->Remove();
+
+ LOGD("No training data so removed all relevant files.");
}
_status = DELETED;
diff --git a/mv_machine_learning/face_recognition/src/nntrainer_dsm.cpp b/mv_machine_learning/face_recognition/src/nntrainer_dsm.cpp
index b732f685..19896a2e 100644
--- a/mv_machine_learning/face_recognition/src/nntrainer_dsm.cpp
+++ b/mv_machine_learning/face_recognition/src/nntrainer_dsm.cpp
@@ -28,14 +28,30 @@ void NNTrainerDSM::PrintHeader(FeaVecHeader &fvh)
{
LOGD("signature = %u", fvh.signature);
LOGD("feature vector size = %zu", fvh.feature_size);
- LOGD("one hot encoding table size = %zu", fvh.one_hot_table_size);
+ LOGD("label count = %zu", fvh.label_cnt);
LOGD("data set count = %u", fvh.data_set_cnt);
}
NNTrainerDSM::NNTrainerDSM() : DataSetManager()
{}
-void NNTrainerDSM::LoadDataSet(const string file_name)
+void NNTrainerDSM::AddDataSet(std::vector<float> &feature_vec, const unsigned int label_idx,
+ const unsigned int label_cnt)
+{
+ _data.push_back(feature_vec);
+ _label_index.push_back(label_idx);
+
+ vector<float> oneHotEncoding;
+
+ for (size_t num = 0; num < label_cnt; ++num)
+ oneHotEncoding.push_back(label_idx == num ? 1.0f : 0.0f);
+
+ _labels.push_back(oneHotEncoding);
+ _feature_vector_size = feature_vec.size();
+ _label_count = label_cnt;
+}
+
+void NNTrainerDSM::LoadDataSet(const string file_name, unsigned int new_label_cnt)
{
std::ifstream inFile(file_name);
@@ -59,30 +75,38 @@ void NNTrainerDSM::LoadDataSet(const string file_name)
if (FeatureVectorManager::feature_vector_signature != fvh.signature)
throw InvalidOperation("Wrong feature vector header.");
- size_t line_size_in_bytes = fvh.feature_size * sizeof(float) + fvh.one_hot_table_size * sizeof(float);
+ /*
+ * stride line format is as follows
+ * ********************************
+ * ____________________________
+ * |feature vector|label index|
+ * ----------------------------
+ */
+ size_t line_size_in_bytes = fvh.feature_size * sizeof(float) + sizeof(float);
_feature_vector_size = fvh.feature_size;
- _label_size = fvh.one_hot_table_size;
+ _label_count = fvh.label_cnt;
_data_set_length = line_size_in_bytes;
- vector<float> line_data(fvh.feature_size + fvh.one_hot_table_size);
+ vector<float> line_data(fvh.feature_size + 1);
for (size_t idx = 0; idx < fvh.data_set_cnt; ++idx) {
- inFile.read((char *) line_data.data(), line_size_in_bytes);
+ inFile.read((char *)line_data.data(), line_size_in_bytes);
vector<float> data;
+
copy_n(line_data.begin(), _feature_vector_size, back_inserter(data));
_data.push_back(data);
- int label_idx = 0;
- vector<float> label;
+ unsigned int label_idx;
- for (size_t num = 0; num < fvh.one_hot_table_size; ++num) {
- if (line_data[fvh.feature_size + num] == 1.0f)
- label_idx = num;
+ memcpy(&label_idx, (void*)(line_data.data() + _feature_vector_size), 4);
+
+ vector<float> label;
- label.push_back((float) line_data[fvh.feature_size + num]);
- }
+ // max label count may be changed so update one hot encoding table.
+ for (size_t num = 0; num < new_label_cnt; ++num)
+ label.push_back(label_idx == num ? 1.0f : 0.0f);
_labels.push_back(label);
_label_index.push_back(label_idx);
diff --git a/mv_machine_learning/face_recognition/src/nntrainer_fvm.cpp b/mv_machine_learning/face_recognition/src/nntrainer_fvm.cpp
index 24f4ad23..8eb7b7bf 100644
--- a/mv_machine_learning/face_recognition/src/nntrainer_fvm.cpp
+++ b/mv_machine_learning/face_recognition/src/nntrainer_fvm.cpp
@@ -25,14 +25,14 @@ using namespace mediavision::machine_learning::exception;
NNTrainerFVM::NNTrainerFVM(const string feature_vector_file) : FeatureVectorManager(feature_vector_file)
{}
-void NNTrainerFVM::WriteHeader(size_t feature_size, size_t one_hot_table_size, unsigned int data_set_cnt)
+void NNTrainerFVM::WriteHeader(size_t feature_size, size_t label_cnt, unsigned int data_set_cnt)
{
ofstream outFile { _feature_vector_file, ios::out | ios::binary | ios::app };
if (!outFile.is_open())
throw InvalidOperation("fail to open a file");
- FeaVecHeader fvHeader { FeatureVectorManager::feature_vector_signature, feature_size, one_hot_table_size,
+ FeaVecHeader fvHeader { FeatureVectorManager::feature_vector_signature, feature_size, label_cnt,
data_set_cnt };
outFile.write((char *) &fvHeader, sizeof(FeaVecHeader));
@@ -55,18 +55,17 @@ void NNTrainerFVM::ReadHeader(FeaVecHeader &header)
throw InvalidParameter("wrong feature vector file header.");
}
-void NNTrainerFVM::WriteFeatureVec(vector<float> &feature_vec, const int max_label, const int label_index)
+void NNTrainerFVM::StoreData(vector<vector<float> > &features_vec,
+ vector<unsigned int> &label_index)
{
ofstream outFile { _feature_vector_file, ios::out | ios::binary | ios::app };
if (!outFile.is_open())
throw InvalidOperation("fail to open a file.");
- outFile.write(reinterpret_cast<char *>(feature_vec.data()), feature_vec.size() * sizeof(float));
-
- for (int idx = 0; idx < max_label; ++idx) {
- float oneHotTable = (label_index == idx) ? 1.0f : 0.0f;
- outFile.write((char *) &oneHotTable, sizeof(float));
+ for (size_t idx = 0; idx < features_vec.size(); ++idx) {
+ outFile.write(reinterpret_cast<char *>(features_vec[idx].data()), features_vec[idx].size() * sizeof(float));
+ outFile.write(reinterpret_cast<char *>(&label_index[idx]), sizeof(unsigned int));
}
}
diff --git a/mv_machine_learning/training/include/data_set_manager.h b/mv_machine_learning/training/include/data_set_manager.h
index 6e86b3b8..a4fa7eca 100644
--- a/mv_machine_learning/training/include/data_set_manager.h
+++ b/mv_machine_learning/training/include/data_set_manager.h
@@ -17,7 +17,6 @@
#ifndef __DATA_SET_MANAGER_H__
#define __DATA_SET_MANAGER_H__
-#include <iostream>
#include <fstream>
#include <vector>
@@ -30,7 +29,7 @@ protected:
std::vector<std::vector<float> > _labels;
std::vector<unsigned int> _label_index;
size_t _feature_vector_size;
- size_t _label_size;
+ size_t _label_count;
size_t _data_set_length;
public:
@@ -42,11 +41,13 @@ public:
std::vector<std::vector<float> > &GetData(void);
std::vector<std::vector<float> > &GetLabel(void);
size_t GetFeaVecSize(void);
- size_t GetLabelSize(void);
+ size_t GetLabelCnt(void);
size_t GetDataSetLen(void);
std::vector<unsigned int> &GetLabelIdx(void);
- virtual void LoadDataSet(const std::string file_name) = 0;
+ virtual void LoadDataSet(const std::string file_name, unsigned int new_label_cnt) = 0;
+ virtual void AddDataSet(std::vector<float> &feature_vec, const unsigned int label_idx,
+ const unsigned int label_cnt) = 0;
};
#endif \ No newline at end of file
diff --git a/mv_machine_learning/training/include/feature_vector_manager.h b/mv_machine_learning/training/include/feature_vector_manager.h
index ad28063b..19d25a0d 100644
--- a/mv_machine_learning/training/include/feature_vector_manager.h
+++ b/mv_machine_learning/training/include/feature_vector_manager.h
@@ -28,7 +28,7 @@
typedef struct {
unsigned int signature;
size_t feature_size;
- size_t one_hot_table_size;
+ size_t label_cnt;
unsigned int data_set_cnt;
} FeaVecHeader;
@@ -50,9 +50,10 @@ public:
static void GetVecFromXRGB(unsigned char *in_data, std::vector<float> &vec, unsigned int in_width,
unsigned int in_height, unsigned int re_width, unsigned int re_height);
- virtual void WriteHeader(size_t feature_size, size_t one_hot_table_size, unsigned int data_set_cnt) = 0;
+ virtual void WriteHeader(size_t feature_size, size_t label_cnt, unsigned int data_set_cnt) = 0;
virtual void ReadHeader(FeaVecHeader &header) = 0;
- virtual void WriteFeatureVec(std::vector<float> &feature_vec, const int max_label, const int label_index) = 0;
+ virtual void StoreData(std::vector<std::vector<float> > &features_vec,
+ std::vector<unsigned int> &label_index) = 0;
virtual void Remove() = 0;
static constexpr unsigned int feature_vector_signature = 0xFEA09841;
diff --git a/mv_machine_learning/training/include/training_model.h b/mv_machine_learning/training/include/training_model.h
index 7a0f11fa..6d21294f 100644
--- a/mv_machine_learning/training/include/training_model.h
+++ b/mv_machine_learning/training/include/training_model.h
@@ -62,6 +62,7 @@ public:
void Compile();
void Train();
void RemoveModel();
+ void getWeights(float **weights, size_t *size, std::string name);
virtual void ConfigureModel(int num_of_class) = 0;
virtual TrainingEngineBackendInfo &GetTrainingEngineInfo() = 0;
diff --git a/mv_machine_learning/training/src/data_set_manager.cpp b/mv_machine_learning/training/src/data_set_manager.cpp
index e8952cce..75699333 100644
--- a/mv_machine_learning/training/src/data_set_manager.cpp
+++ b/mv_machine_learning/training/src/data_set_manager.cpp
@@ -19,7 +19,7 @@
using namespace std;
DataSetManager::DataSetManager()
- : _data(), _labels(), _label_index(), _feature_vector_size(), _label_size(), _data_set_length()
+ : _data(), _labels(), _label_index(), _feature_vector_size(), _label_count(), _data_set_length()
{}
DataSetManager::~DataSetManager()
@@ -68,9 +68,9 @@ size_t DataSetManager::GetFeaVecSize(void)
return _feature_vector_size;
}
-size_t DataSetManager::GetLabelSize(void)
+size_t DataSetManager::GetLabelCnt(void)
{
- return _label_size;
+ return _label_count;
}
size_t DataSetManager::GetDataSetLen(void)
diff --git a/mv_machine_learning/training/src/training_model.cpp b/mv_machine_learning/training/src/training_model.cpp
index ee022eda..6b11f05c 100644
--- a/mv_machine_learning/training/src/training_model.cpp
+++ b/mv_machine_learning/training/src/training_model.cpp
@@ -76,8 +76,6 @@ void TrainingModel::ApplyDataSet(unique_ptr<DataSetManager> &data_set)
throw InvalidOperation("Fail to add data to dataset.", ret);
}
- data_set->Clear();
-
ret = _training->SetDataset(_model.get(), _data_set.get());
if (ret != TRAINING_ENGINE_ERROR_NONE)
throw InvalidOperation("Fail to set dataset to model.", ret);
@@ -119,6 +117,14 @@ void TrainingModel::Train()
SaveModel(_internal_model_file);
}
+void TrainingModel::getWeights(float **weights, size_t *size, std::string name)
+{
+ int ret = _training->GetWeightWithLayer(_model.get(), weights, size, name);
+ if (ret != TRAINING_ENGINE_ERROR_NONE)
+ throw InvalidOperation("Fail to get weights.", ret);
+
+}
+
void TrainingModel::RemoveModel()
{
RemoveModel(_internal_model_file);