0
0
RhSolutions-ML/RhSolutions.ML.Builder/Program.cs
2023-09-20 08:18:22 +03:00

42 lines
2.0 KiB
C#

using Microsoft.ML;
namespace RhSolutions.ML.Builder
{
    /// <summary>
    /// Console entry point that trains a multiclass classifier predicting the <c>Type</c> column
    /// from the <c>Name</c> column of <c>Data/train.tsv</c> and saves the model to <c>Models/model.zip</c>.
    /// </summary>
    public class Program
    {
        // Project directory, resolved from the build output directory.
        // AppContext.BaseDirectory is always populated, unlike Path.GetDirectoryName(args[0]),
        // which may be null or "" depending on how the process was launched.
        // NOTE(review): three ".." assumes the default bin/<Configuration>/<TargetFramework>/ output
        // layout (i.e. running via `dotnet run` / the IDE) — confirm if the output path is customised.
        private static readonly string _projectRoot =
            Path.GetFullPath(Path.Combine(AppContext.BaseDirectory, "..", "..", ".."));

        // Fixed seed so that repeated training runs are reproducible.
        private static readonly MLContext _mlContext = new MLContext(seed: 0);

        /// <summary>
        /// Loads the training data, builds and fits the pipeline, and writes the trained model to disk.
        /// </summary>
        /// <exception cref="FileNotFoundException">The training data file <c>Data/train.tsv</c> does not exist.</exception>
        public static void Main()
        {
            string trainDataPath = Path.Combine(_projectRoot, "Data", "train.tsv");
            if (!File.Exists(trainDataPath))
            {
                throw new FileNotFoundException("Training data file not found.", trainDataPath);
            }

            IDataView trainDataView = _mlContext.Data.LoadFromTextFile<Product>(trainDataPath, hasHeader: true);
            IEstimator<ITransformer> pipeline = ProcessData();
            ITransformer trainedModel = BuildAndTrainModel(trainDataView, pipeline);
            SaveModelAsFile(_mlContext, trainDataView.Schema, trainedModel);
        }

        /// <summary>
        /// Builds the data-preparation part of the pipeline: maps <c>Type</c> to the key-typed
        /// <c>Label</c> column, featurizes the <c>Name</c> text and exposes it as <c>Features</c>.
        /// </summary>
        /// <returns>The (unfitted) data-preparation estimator chain.</returns>
        private static IEstimator<ITransformer> ProcessData()
        {
            return _mlContext.Transforms.Conversion.MapValueToKey(inputColumnName: "Type", outputColumnName: "Label")
                .Append(_mlContext.Transforms.Text.FeaturizeText(inputColumnName: "Name", outputColumnName: "NameFeaturized"))
                .Append(_mlContext.Transforms.Concatenate("Features", "NameFeaturized"))
                // Cache the prepared data so the multi-pass trainer appended later does not
                // recompute the featurization on every pass.
                .AppendCacheCheckpoint(_mlContext);
        }

        /// <summary>
        /// Appends an SDCA maximum-entropy multiclass trainer to <paramref name="pipeline"/>, followed by a
        /// mapping of the predicted key back to its original value, and fits the result.
        /// </summary>
        /// <param name="trainingDataView">Training data.</param>
        /// <param name="pipeline">Data-preparation estimator producing <c>Label</c> and <c>Features</c>.</param>
        /// <returns>The fitted model.</returns>
        private static ITransformer BuildAndTrainModel(IDataView trainingDataView, IEstimator<ITransformer> pipeline)
        {
            var trainingPipeline = pipeline
                .Append(_mlContext.MulticlassClassification.Trainers.SdcaMaximumEntropy("Label", "Features"))
                // Turn the predicted key back into the original Type value.
                .Append(_mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel"));
            return trainingPipeline.Fit(trainingDataView);
        }

        /// <summary>
        /// Saves <paramref name="model"/> together with its input schema to <c>Models/model.zip</c>
        /// under the project directory, creating the <c>Models</c> directory if needed.
        /// </summary>
        /// <param name="mlContext">Context used to serialize the model.</param>
        /// <param name="trainingDataViewSchema">Schema of the data the model was trained on.</param>
        /// <param name="model">Fitted model to persist.</param>
        private static void SaveModelAsFile(MLContext mlContext, DataViewSchema trainingDataViewSchema, ITransformer model)
        {
            string modelsDirectory = Path.Combine(_projectRoot, "Models");
            // Writing the zip into a missing directory would fail, so ensure it exists first.
            Directory.CreateDirectory(modelsDirectory);
            mlContext.Model.Save(model, trainingDataViewSchema, Path.Combine(modelsDirectory, "model.zip"));
        }
    }
}