
Introducing the ML.NET Text Classification API (preview)

ML.NET is an open-source, cross-platform machine learning framework for .NET developers that enables integration of custom machine learning models into .NET apps.

A few weeks ago, we shared a blog post with updates on what we’ve been working on in ML.NET across the framework and tooling. Some of those updates included components of our deep learning plan. An important part of that plan is the introduction of scenario-focused APIs in ML.NET.

After months of work and collaborations with TorchSharp and Microsoft Research, today we’re excited to announce the Text Classification API.

The Text Classification API makes it easier for you to train custom text classification models in ML.NET using the latest state-of-the-art deep learning techniques.

What is text classification?

Text classification, as the name implies, is the process of applying labels or categories to text.

Common use cases include:

  • Categorizing e-mail as spam or not spam
  • Analyzing sentiment as positive or negative from customer reviews
  • Applying labels to support tickets

Solving text classification with machine learning

Classification is a common problem in machine learning. There are a variety of algorithms you can use to train a classification model. Text classification is a subcategory of classification which deals specifically with raw text. Text poses interesting challenges because you have to account for the context and semantics in which the text occurs. As such, encoding meaning and context can be difficult. In recent years, deep learning models have emerged as a promising technique to solve natural language problems. More specifically, a type of neural network known as transformers has become the predominant way of solving natural language problems like text classification, translation, summarization, and question answering.

Transformers were introduced in the paper “Attention Is All You Need”. Some popular transformer architectures for natural language tasks include:

  • Bidirectional Encoder Representations from Transformers (BERT)
  • Robustly Optimized BERT Pretraining Approach (RoBERTa)
  • Generative Pre-trained Transformer 2 (GPT-2)
  • Generative Pre-trained Transformer 3 (GPT-3)

At a high level, transformers are a model architecture consisting of encoding and decoding layers. The encoder takes raw text as input and maps the input to a numerical representation (including context) to produce features. The decoder uses information from the encoder to produce output such as a category or label in the case of text classification. What makes these layers so special is the concept of attention. Attention is the idea of focusing on specific parts of an input based on the importance of their context in relation to other inputs in a sequence. For example, let’s say I’m categorizing news articles based on the headline. Not all words in the headline are relevant. In a headline like “Auto sales are at an all-time high”, a word like “sales” might get more attention and lead to labeling the article as business or finance.
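
To make attention more concrete, here is a toy C# sketch of scaled dot-product attention, the attention formulation used in “Attention Is All You Need”. The three embeddings are made-up numbers, and unlike a real transformer, the sketch reuses the embeddings directly instead of learned query/key/value projections and uses a single attention head. It only shows the core mechanism: each token's output becomes a weighted mix of all tokens, with the weights produced by a softmax over similarity scores.

```csharp
using System;
using System.Globalization;
using System.Linq;

// Toy scaled dot-product attention: weights = softmax(Q·Kᵀ / sqrt(d)), output = weights·V.
// Each row of q, k and v is the vector for one token in the sequence.

double[] Softmax(double[] scores)
{
    double max = scores.Max();  // subtract the max for numerical stability
    double[] exps = scores.Select(s => Math.Exp(s - max)).ToArray();
    double total = exps.Sum();
    return exps.Select(e => e / total).ToArray();
}

double Dot(double[] a, double[] b) => a.Zip(b, (p, r) => p * r).Sum();

(double[][] Output, double[][] Weights) Attention(double[][] q, double[][] k, double[][] v)
{
    double scale = Math.Sqrt(q[0].Length);
    // weights[i][j]: how strongly token i attends to token j (each row sums to 1)
    double[][] weights = q
        .Select(qi => Softmax(k.Select(kj => Dot(qi, kj) / scale).ToArray()))
        .ToArray();
    // output[i]: average of the value vectors, weighted by row i of the weights
    double[][] output = weights
        .Select(row => Enumerable.Range(0, v[0].Length)
            .Select(col => row.Select((wij, j) => wij * v[j][col]).Sum())
            .ToArray())
        .ToArray();
    return (output, weights);
}

// Made-up 2-dimensional embeddings for three headline tokens.
string[] tokens = { "auto", "sales", "high" };
double[][] embeddings =
{
    new[] { 1.0, 0.0 },
    new[] { 0.9, 0.8 },
    new[] { 0.1, 0.2 },
};

// Real transformers derive Q, K and V from learned projections; this sketch reuses the embeddings.
var (attended, attnWeights) = Attention(embeddings, embeddings, embeddings);

for (int i = 0; i < tokens.Length; i++)
{
    string line = string.Join(", ", attnWeights[i].Select(w => w.ToString("F2", CultureInfo.InvariantCulture)));
    Console.WriteLine($"{tokens[i]} attends to [auto, sales, high] with weights [{line}]");
}
```

Dividing by the square root of the vector size keeps the dot products from growing with the dimension, which would otherwise push the softmax toward putting nearly all weight on one token.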

Figure: High-level transformer network architecture

As with most neural networks, training transformers from scratch can be expensive because it requires large amounts of data and compute. However, you don’t always have to train from scratch. Using a technique known as fine-tuning, you can take a pre-trained model and retrain the layers specific to your domain or problem using your own data. This gives you a model that’s more tailored to your problem without the cost of training the entire model from scratch.
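
The following toy C# sketch illustrates the idea behind fine-tuning under heavy simplifying assumptions: a hypothetical “pre-trained” layer (a small, made-up table of word vectors) stays frozen, and only a logistic-regression classification head is trained on a handful of labeled reviews. This is not how ML.NET implements fine-tuning; in practice, fine-tuning often updates some or all of the pre-trained weights as well.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical frozen "pre-trained" layer: made-up 2-dimensional word vectors.
var pretrained = new Dictionary<string, double[]>
{
    ["great"] = new[] { 0.9, 0.1 },
    ["like"] = new[] { 0.7, 0.2 },
    ["bad"] = new[] { 0.1, 0.9 },
    ["awful"] = new[] { 0.0, 1.0 },
};

// Encode a sentence by averaging the vectors of the words the frozen layer knows.
double[] Encode(string sentence)
{
    var vectors = sentence.ToLowerInvariant()
        .Split(' ', StringSplitOptions.RemoveEmptyEntries)
        .Where(word => pretrained.ContainsKey(word))
        .Select(word => pretrained[word])
        .ToList();
    if (vectors.Count == 0) return new double[2];
    return new[] { vectors.Average(vec => vec[0]), vectors.Average(vec => vec[1]) };
}

// Trainable classification head: logistic regression on the frozen features.
double[] weights = new double[2];
double bias = 0.0;
double Predict(double[] f) =>
    1.0 / (1.0 + Math.Exp(-(weights[0] * f[0] + weights[1] * f[1] + bias)));

// Small labeled dataset for the target domain (1 = positive, 0 = negative).
var data = new (string Text, int Label)[]
{
    ("great steak", 1), ("I like this place", 1), ("bad service", 0), ("awful food", 0),
};

// Gradient descent on the log loss; only the head's weights and bias change.
const double learningRate = 0.5;
for (int epoch = 0; epoch < 200; epoch++)
{
    foreach (var (text, label) in data)
    {
        double[] features = Encode(text);
        double error = Predict(features) - label;  // derivative of the log loss w.r.t. the logit
        weights[0] -= learningRate * error * features[0];
        weights[1] -= learningRate * error * features[1];
        bias -= learningRate * error;
    }
}

double pPositive = Predict(Encode("I like great food"));
Console.WriteLine($"P(positive) for 'I like great food': {pPositive:F2}");
```

Because the `pretrained` table is never modified, the only cost of adapting the model to the new data is training the two weights and the bias of the head.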

The Text Classification API (preview)

Now that you have a general overview of how text classification problems can be solved using deep learning, let’s take a look at how we’ve incorporated many of these techniques into the Text Classification API.

Figure: ML.NET Text Classification API architecture

The Text Classification API is powered by TorchSharp, a .NET library that provides access to libtorch, the library that powers PyTorch. TorchSharp contains the building blocks for training neural networks from scratch in .NET. However, the TorchSharp components are low-level, and building neural networks from scratch has a steep learning curve. In ML.NET, we’ve abstracted some of that complexity to the scenario level.

In direct collaboration with Microsoft Research, we’ve taken a TorchSharp implementation of NAS-BERT, a variant of BERT obtained with neural architecture search, and added it to ML.NET. Starting from a pre-trained version of this model, the Text Classification API fine-tunes it on your data.

Get started with the Text Classification API

For a complete code sample of the Text Classification API, see the Text Classification API notebook.

The Text Classification API is part of the latest 2.0.0 and 0.20.0 preview versions of ML.NET.

To use it, install Microsoft.ML, Microsoft.ML.TorchSharp, and a TorchSharp package that matches your hardware (CPU, or CUDA on Windows or Linux). You can use the NuGet Package Manager in Visual Studio or the dotnet CLI:

dotnet add package Microsoft.ML --prerelease
dotnet add package Microsoft.ML.TorchSharp --prerelease

# If using CPU
dotnet add package TorchSharp-cpu

# If using GPU, install the package for your OS instead
# dotnet add package TorchSharp-cuda-windows
# dotnet add package TorchSharp-cuda-linux

Then, reference the packages and use the Text Classification API in your pipeline.

// Reference packages
using Microsoft.ML;
using Microsoft.ML.TorchSharp;

// Initialize MLContext
var mlContext = new MLContext();

// Load your data (a tiny in-memory sample here; use your own labeled data in practice)
var reviews = new[]
{
    new {Text = "This is a bad steak", Sentiment = "Negative"},
    new {Text = "I really like this restaurant", Sentiment = "Positive"}
};

var reviewsDV = mlContext.Data.LoadFromEnumerable(reviews);

// Define your training pipeline
var pipeline =
        // Convert the string label into a key, the label type multiclass trainers expect
        mlContext.Transforms.Conversion.MapValueToKey("Label", "Sentiment")
            // Fine-tune the pre-trained model on the "Text" column
            .Append(mlContext.MulticlassClassification.Trainers.TextClassification(numberOfClasses: 2, sentence1ColumnName: "Text"))
            // Convert the predicted key back into the original string label
            .Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel"));

// Train the model
var model = pipeline.Fit(reviewsDV);
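
In the pipeline, MapValueToKey turns the string Sentiment column into a key, and MapKeyToValue turns the trainer's predicted key back into a string. The plain C# sketch below (no ML.NET types) mimics that round trip under simplifying assumptions: labels get indices in order of first occurrence, the per-class scores are made-up numbers, and the highest score wins. The names `valueToKey` and `keyToValue` are hypothetical. ML.NET's key type is 1-based, with 0 reserved for missing values, so the zero-based indices here are not what ML.NET stores internally.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Simplified stand-in for MapValueToKey: give each distinct label an index
// in order of first occurrence.
string[] sentiments = { "Negative", "Positive", "Negative" };
var valueToKey = new Dictionary<string, int>();
foreach (string s in sentiments)
{
    if (!valueToKey.ContainsKey(s)) valueToKey[s] = valueToKey.Count;
}
string[] keyToValue = valueToKey.OrderBy(kv => kv.Value).Select(kv => kv.Key).ToArray();

// With numberOfClasses = 2, the classifier produces one score per class.
// These scores are made up; the index of the highest score is the predicted key.
float[] scores = { 0.2f, 0.8f };
int predictedKey = Array.IndexOf(scores, scores.Max());

// Simplified stand-in for MapKeyToValue: turn the key back into the original string.
string predictedLabel = keyToValue[predictedKey];
Console.WriteLine(predictedLabel);  // prints "Positive"
```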

For this sample, since there are only two classes (“Positive” and “Negative”), the numberOfClasses parameter is set to 2. The API supports up to two sentences as input, each limited to 512 tokens. A token roughly corresponds to a word, although longer or less common words can be split into several tokens. If a sentence is longer than 512 tokens, it’s automatically truncated for you. In this case, since there’s only one sentence, only sentence1ColumnName is set.
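
As a rough illustration of the 512-token limit, the sketch below truncates an input to at most 512 tokens. It assumes that each whitespace-separated word is one token, which is a simplification: the API uses its own subword tokenizer, so real token counts will differ, and the API performs the truncation automatically. `Truncate` and `MaxTokens` are hypothetical names.

```csharp
using System;
using System.Linq;

// Simplified stand-in for the 512-token limit. Assumption: each whitespace-separated
// word is one token (the real API uses a subword tokenizer, so counts will differ).
const int MaxTokens = 512;

// Keep at most maxTokens tokens and drop the rest.
string[] Truncate(string sentence, int maxTokens) =>
    sentence.Split(' ', StringSplitOptions.RemoveEmptyEntries).Take(maxTokens).ToArray();

// A made-up review that is 600 "tokens" long.
string longReview = string.Join(" ", Enumerable.Repeat("tasty", 600));
string[] kept = Truncate(longReview, MaxTokens);

Console.WriteLine($"{longReview.Split(' ').Length} tokens in, {kept.Length} kept");  // prints "600 tokens in, 512 kept"
```

Truncation keeps the beginning of the text, so any information that appears only after the limit is not seen by the model.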

The training produces an ML.NET model that you can use to make predictions with either the Transform method or a PredictionEngine.

What’s next?

This is one of the first steps toward enabling natural language scenarios in ML.NET. There are still a few limitations when using the Text Classification API, such as not being able to use the Evaluate method to calculate evaluation metrics. Based on your feedback, we plan to:

  • Make improvements to the Text Classification API
  • Introduce other scenario-based APIs

We want to hear from you. Help us prioritize and make these experiences the best they can be by providing feedback and raising issues in the dotnet/machinelearning GitHub repo.

Get started and resources

Learn more about ML.NET, Model Builder, and the ML.NET CLI in Microsoft Docs.

If you run into any issues or have feature requests or feedback, please file an issue in the ML.NET repo.

Join the ML.NET Community Discord or the #machine-learning channel on the .NET Development Discord.

Tune in to the Machine Learning .NET Community Standup every other Wednesday at 10am Pacific Time.

The post Introducing the ML.NET Text Classification API (preview) appeared first on .NET Blog.



source https://devblogs.microsoft.com/dotnet/introducing-the-ml-dotnet-text-classification-api-preview/
