Multi-Tenancy / Global Filters in Entity Framework

In this article I describe a method for adding global filters in Entity Framework. The version I used for this example is 6.1.

What does it do?

Imagine that you have a multi-tenant database, i.e. a single database that holds data for multiple tenants. The usual way to do this is to add a “tenantId” column (or equivalent) to every table that needs to be tenant-specific.

What can work (but not really)?

Every query a developer writes must take tenancy into account: lambda expressions, complex LINQ queries with joins, and so on. To make sure you never run into the disastrous situation where tenant Y’s data starts showing up for tenant X, you would have to set up a strict review process in which every new query is scrutinized (cough).
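To make the risk concrete, here is a minimal LINQ to Objects sketch (no EF involved; the `Order` entity, its `TenantId` column and the sample rows are hypothetical) of one shared table and two versions of the same query, one of which forgot the tenant predicate:

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical entity: every tenant-specific table carries a TenantId column.
public class Order
{
    public int Id { get; set; }
    public string TenantId { get; set; }
    public decimal Total { get; set; }
}

public static class ManualTenancy
{
    // One shared "table" holding rows for two tenants.
    public static readonly List<Order> Orders = new List<Order>
    {
        new Order { Id = 1, TenantId = "X", Total = 10m },
        new Order { Id = 2, TenantId = "Y", Total = 20m },
        new Order { Id = 3, TenantId = "X", Total = 30m },
    };

    // Correct: the developer remembered the tenant predicate.
    public static List<Order> BigOrders(string tenantId) =>
        Orders.Where(o => o.TenantId == tenantId && o.Total > 15m).ToList();

    // Buggy: the same query with the tenant predicate forgotten.
    public static List<Order> BigOrdersForgotFilter(string tenantId) =>
        Orders.Where(o => o.Total > 15m).ToList();
}
```

Nothing in the second method looks wrong in isolation, which is exactly why relying on review alone is fragile.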

The Code

The first thing you need to make sure of is that your DbContext class exposes IDbSet<> instead of DbSet<>, so that the property can later hold an implementation of our own, e.g.:

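```csharp
using System.Data.Entity;

public class MyDataContext : DbContext
{
    // IDbSet<> rather than DbSet<>, so a FilteredDbSet<> can be assigned later.
    public IDbSet<MyEntity> MyEntities { get; set; }
}
```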

Then you need your own implementation of IDbSet<> that wraps the default DbSet<> and applies the filter:

```csharp
using System;
using System.Collections;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.ComponentModel;
using System.Data.Entity;
using System.Data.Entity.Infrastructure;
using System.Linq;
using System.Linq.Expressions;

public class FilteredDbSet<TEntity> : IDbSet<TEntity>, IOrderedQueryable<TEntity>, IOrderedQueryable,
    IQueryable<TEntity>, IQueryable, IEnumerable<TEntity>, IEnumerable, IListSource
    where TEntity : class
{
    private readonly DbSet<TEntity> _set;
    private readonly Action<TEntity> _initializeEntity;
    private readonly Expression<Func<TEntity, bool>> _filter;

    public FilteredDbSet(DbContext context)
        : this(context.Set<TEntity>(), i => true, null) { }

    public FilteredDbSet(DbContext context, Expression<Func<TEntity, bool>> filter)
        : this(context.Set<TEntity>(), filter, null) { }

    public FilteredDbSet(DbContext context, Expression<Func<TEntity, bool>> filter, Action<TEntity> initializeEntity)
        : this(context.Set<TEntity>(), filter, initializeEntity) { }

    private FilteredDbSet(DbSet<TEntity> set, Expression<Func<TEntity, bool>> filter, Action<TEntity> initializeEntity)
    {
        _set = set;
        _filter = filter;
        MatchesFilter = filter.Compile(); // in-memory version of the filter, used to validate entities
        _initializeEntity = initializeEntity;
    }

    public Expression<Func<TEntity, bool>> Filter => _filter;

    public Func<TEntity, bool> MatchesFilter { get; private set; }

    public IQueryable<TEntity> Include(string path)
    {
        return _set.Include(path).Where(_filter).AsQueryable();
    }

    // Escape hatch: bypasses the filter entirely.
    public IQueryable<TEntity> Unfiltered()
    {
        return _set;
    }

    public void ThrowIfEntityDoesNotMatchFilter(TEntity entity)
    {
        if (!MatchesFilter(entity))
            throw new ArgumentOutOfRangeException(nameof(entity), "The entity does not match the filter of this set.");
    }

    public TEntity Add(TEntity entity)
    {
        DoInitializeEntity(entity); // e.g. stamp the current tenant id
        ThrowIfEntityDoesNotMatchFilter(entity);
        return _set.Add(entity);
    }

    public TEntity Attach(TEntity entity)
    {
        ThrowIfEntityDoesNotMatchFilter(entity);
        return _set.Attach(entity);
    }

    public TDerivedEntity Create<TDerivedEntity>() where TDerivedEntity : class, TEntity
    {
        var entity = _set.Create<TDerivedEntity>();
        DoInitializeEntity(entity);
        return entity;
    }

    public TEntity Create()
    {
        var entity = _set.Create();
        DoInitializeEntity(entity);
        return entity;
    }

    public TEntity Find(params object[] keyValues)
    {
        var entity = _set.Find(keyValues);
        if (entity == null)
            return null;
        ThrowIfEntityDoesNotMatchFilter(entity);
        return entity;
    }

    public TEntity Remove(TEntity entity)
    {
        ThrowIfEntityDoesNotMatchFilter(entity);
        return _set.Remove(entity);
    }

    // Note: Local exposes every tracked entity of this type and is not filtered.
    public ObservableCollection<TEntity> Local => _set.Local;

    // Enumeration and query composition always go through Where(_filter).
    IEnumerator<TEntity> IEnumerable<TEntity>.GetEnumerator() => _set.Where(_filter).GetEnumerator();
    IEnumerator IEnumerable.GetEnumerator() => _set.Where(_filter).GetEnumerator();

    Type IQueryable.ElementType => typeof(TEntity);
    Expression IQueryable.Expression => _set.Where(_filter).Expression;
    IQueryProvider IQueryable.Provider => _set.AsQueryable().Provider;

    bool IListSource.ContainsListCollection => false;
    IList IListSource.GetList() { throw new InvalidOperationException(); }

    private void DoInitializeEntity(TEntity entity)
    {
        _initializeEntity?.Invoke(entity);
    }

    // Note: raw SQL bypasses the filter; add the tenant condition to the SQL yourself.
    public DbSqlQuery<TEntity> SqlQuery(string sql, params object[] parameters)
    {
        return _set.SqlQuery(sql, parameters);
    }
}
```
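The read side of the trick lives in `IQueryable.Expression` and `IQueryable.Provider`: the set hands out the wrapped set's provider together with an expression tree that already contains `Where(_filter)`, so anything a developer composes on top inherits the filter. A minimal EF-free sketch of just that part, using LINQ to Objects as the provider (the `FilteredQueryable` and `Row` names are hypothetical):

```csharp
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

// Same idea as FilteredDbSet's IQueryable members, minus EF:
// expose the source's provider, but an expression that already contains Where(filter).
public class FilteredQueryable<T> : IQueryable<T>
{
    private readonly IQueryable<T> _filtered;

    public FilteredQueryable(IQueryable<T> source, Expression<Func<T, bool>> filter)
    {
        _filtered = source.Where(filter);
    }

    public Type ElementType => typeof(T);
    public Expression Expression => _filtered.Expression; // filter is baked into the tree
    public IQueryProvider Provider => _filtered.Provider;

    public IEnumerator<T> GetEnumerator() => _filtered.GetEnumerator();
    IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
}

// Hypothetical tenant-specific row.
public class Row
{
    public int Id { get; set; }
    public string TenantId { get; set; }
}
```

A query such as `scoped.Where(r => r.Id > 1)` on a `FilteredQueryable<Row>` scoped to tenant X never sees tenant Y's rows, even though the caller wrote no tenant predicate.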

The two important points to note above are that we can now pass an expression that acts as a where clause, and an action that runs while initializing (adding) an entity. Together they ensure that:

A tenant will not get data that does not belong to it. A tenant cannot add an entity into the realm of another tenant. In addition, Attach, Find and Remove throw for any entity that does not match the filter.
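The write side can be illustrated the same way. The sketch below mirrors the order used by `Add` (initialize first, then validate) and by `Attach` (validate only), with a plain `List<T>` standing in for the EF set; `GuardedList` and `Doc` are hypothetical names used only for illustration:

```csharp
using System;
using System.Collections.Generic;
using System.Linq.Expressions;

// Hypothetical tenant-specific entity.
public class Doc
{
    public string TenantId { get; set; }
}

// Mirrors FilteredDbSet.Add/Attach, with a List<T> instead of a DbSet<T>.
public class GuardedList<T>
{
    private readonly List<T> _items = new List<T>();
    private readonly Func<T, bool> _matchesFilter;
    private readonly Action<T> _initialize;

    public GuardedList(Expression<Func<T, bool>> filter, Action<T> initialize)
    {
        _matchesFilter = filter.Compile(); // same approach as MatchesFilter
        _initialize = initialize;
    }

    public IReadOnlyList<T> Items => _items;

    // Add: stamp the entity first, then check it against the filter.
    public T Add(T entity)
    {
        _initialize?.Invoke(entity);
        ThrowIfNoMatch(entity);
        _items.Add(entity);
        return entity;
    }

    // Attach: no stamping, so an entity from another tenant is rejected.
    public T Attach(T entity)
    {
        ThrowIfNoMatch(entity);
        _items.Add(entity);
        return entity;
    }

    private void ThrowIfNoMatch(T entity)
    {
        if (!_matchesFilter(entity))
            throw new ArgumentOutOfRangeException(nameof(entity), "Entity does not belong to the current tenant.");
    }
}
```

Because `Add` runs the initializer before the check, an entity created with the wrong tenant id is silently re-stamped rather than rejected, while `Attach` refuses it outright.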

Now if we replace the DbSet with something like this:

```csharp
this.MyEntities = new FilteredDbSet<MyEntity>(this, x => x.TenantId == _tenantId, x => x.TenantId = _tenantId);
```

we should be good to go.
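Both lambdas in that line refer to the context's `_tenantId` field rather than copying its value, so the field is read each time the filter is evaluated, not when the FilteredDbSet is built (in EF 6 such captured members typically end up as SQL parameters). A small EF-free sketch of that behavior; `FakeContext`, `SwitchTenant` and `Item` are hypothetical stand-ins:

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

// Hypothetical tenant-specific entity.
public class Item
{
    public string TenantId { get; set; }
}

// Hypothetical stand-in for the DbContext holding _tenantId.
public class FakeContext
{
    private string _tenantId;

    public FakeContext(string tenantId) { _tenantId = tenantId; }

    public void SwitchTenant(string tenantId) { _tenantId = tenantId; }

    // Both lambdas capture 'this', so they read the field's current value when run.
    public Expression<Func<Item, bool>> Filter => x => x.TenantId == _tenantId;
    public Action<Item> Initialize => x => x.TenantId = _tenantId;
}
```

Building the query once and then changing the field changes the result of the very same query object, which is why the context only needs to set `_tenantId` before the sets are used.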

There is one catch, though: replacing the DbSet, although it seems (and ideally should be) simple, really isn’t. If we try to do it while the DbContext initializes, e.g. in its constructor, we get an error because EF hasn’t yet prepared its model; and no, putting it in OnModelCreating doesn’t help either, as the model is only finalized after that method returns.

One solution is to do this via a database initializer, like so:

```csharp
using System.Data.Entity;
using System.Threading;

public sealed class TenancyInitializer<TContext> : IDatabaseInitializer<TContext>
    where TContext : DbContext, IMultiTenantContext
{
    private readonly IDatabaseInitializer<TContext> _chainedInitializer;

    // One flag per closed generic type, i.e. per context type.
    private static int _initializeAlreadyCalledFor = 0;

    public bool AllowForcedCallsToInternalInitializer { get; set; }

    public TenancyInitializer(IDatabaseInitializer<TContext> chainedInitializer, bool allowForcedCallsToInternalInitializer)
    {
        _chainedInitializer = chainedInitializer;
        AllowForcedCallsToInternalInitializer = allowForcedCallsToInternalInitializer;
    }

    public TenancyInitializer(bool allowForcedCallsToInternalInitializer)
        : this(null, allowForcedCallsToInternalInitializer) { }

    public TenancyInitializer(IDatabaseInitializer<TContext> chainedInitializer)
        : this(chainedInitializer, false) { }

    public TenancyInitializer() : this(null, false) { }

    public void InitializeDatabase(TContext context)
    {
        // Runs on every (forced) initialization, i.e. for every new context instance.
        context.ApplyTenancy();

        // Interlocked.Exchange returns the previous value, which is 0 only on the first call.
        var isFirstCall = Interlocked.Exchange(ref _initializeAlreadyCalledFor, 1) == 0;

        if (_chainedInitializer != null && (isFirstCall || AllowForcedCallsToInternalInitializer))
            _chainedInitializer.InitializeDatabase(context);
    }
}
```

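Note that the initializer above leaves room to chain another database initializer, for example one of the Code First initializers that create or drop and recreate the database, or one that seeds data. Because the TenantDataContext constructor shown further down forces initialization every time a context is created, InitializeDatabase (and therefore ApplyTenancy) runs for every new context; the static _initializeAlreadyCalledFor flag is there so that the chained initializer still runs only once, unless AllowForcedCallsToInternalInitializer is set.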
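The run-once guard is the easiest part to get backwards: `Interlocked.Exchange` returns the value the field held before the exchange, so only the first caller observes 0. A minimal sketch of the guard in isolation (the `ChainedInitGuard` class and `ShouldRunChained` method are hypothetical names):

```csharp
using System.Threading;

// Decides whether a chained initializer should run.
public static class ChainedInitGuard
{
    private static int _alreadyCalled; // 0 = never called, 1 = called before

    public static bool ShouldRunChained(bool allowForcedCalls)
    {
        // Exchange sets the flag to 1 and returns the old value atomically,
        // so exactly one caller (the first) sees 0, even under concurrency.
        bool isFirstCall = Interlocked.Exchange(ref _alreadyCalled, 1) == 0;
        return isFirstCall || allowForcedCalls;
    }
}
```

The first call returns `true`, later calls return `false` unless forced calls are allowed.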

Let’s define the interface that marks our DataContext class as having a method for applying tenancy, which the initializer above calls.

```csharp
public interface IMultiTenantContext
{
    void ApplyTenancy(string tenantId = null);
}
```

The implementation below is an example of how the “Schedule” entity can have the filter applied; note the second lambda, which initializes a Schedule entity whenever one is added.

```csharp
using System;
using System.Data.Entity;

public class TenantDataContext : DbContext, IMultiTenantContext
{
    private string _tenantId;

    public IDbSet<Schedule> Schedules { get; set; }

    public TenantDataContext()
    {
        // Force initialization so the registered TenancyInitializer runs for every instance.
        this.Database.Initialize(true);
    }

    public void ApplyTenancy(string tenantId = null)
    {
        if (tenantId != null)
            _tenantId = tenantId;

        if (_tenantId == null)
        {
            var currentContext = Context.AppContext.Current;
            if (currentContext == null)
                throw new InvalidOperationException(
                    $"The '{nameof(Context.AppContext)}' was null; without it one tenant might see another tenant's data.");
            _tenantId = currentContext.TenantId;
        }

        this.Schedules = new FilteredDbSet<Schedule>(this, x => x.TenantId == _tenantId, x => x.TenantId = _tenantId);
    }
}
```
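Stripped of EF, ApplyTenancy resolves the tenant in a fixed order: an explicit argument wins, otherwise an already-set id is kept, otherwise the ambient context is consulted, and if none is available it refuses to continue rather than build unfiltered sets. A minimal sketch of that order; `TenantResolver` and the `ambientTenant` delegate (standing in for `Context.AppContext.Current`) are hypothetical:

```csharp
using System;

// EF-free sketch of the resolution order used by ApplyTenancy.
public class TenantResolver
{
    private string _tenantId;
    private readonly Func<string> _ambientTenant; // stand-in for the ambient AppContext

    public TenantResolver(Func<string> ambientTenant)
    {
        _ambientTenant = ambientTenant;
    }

    public string Resolve(string tenantId = null)
    {
        if (tenantId != null)
            _tenantId = tenantId; // 1) explicit argument wins

        if (_tenantId == null) // 2) otherwise keep an id that was already set
        {
            var ambient = _ambientTenant(); // 3) fall back to the ambient context
            if (ambient == null) // 4) fail closed instead of running unfiltered
                throw new InvalidOperationException("No tenant available; one tenant might see another tenant's data.");
            _tenantId = ambient;
        }

        return _tenantId;
    }
}
```

Failing closed is the important design choice here: a missing tenant becomes an exception, not an unfiltered query.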

Lastly, for completeness’ sake, here is the context class used in the examples above. Keeping it in the logical CallContext makes it available throughout the application and keeps it scoped to a single request, even when that request’s work continues on other threads or across async/await. This is only a convenience, though; you can use any other method of passing the tenant id to the context’s “ApplyTenancy” method.

```csharp
using System;
using System.Runtime.Remoting.Messaging;

[Serializable]
public sealed class AppContext : MarshalByRefObject, ILogicalThreadAffinative
{
    private const string Key = "bcc.adv.con.appCon";

    public string TenantId { get; }

    public AppContext(string tenantId)
    {
        TenantId = tenantId;
    }

    // Stored in the logical call context, which flows with the execution context.
    public static AppContext Current
    {
        get { return (AppContext)CallContext.LogicalGetData(Key); }
        private set { CallContext.LogicalSetData(Key, value); }
    }

    public void SetIntoContext()
    {
        if (Current != null)
            throw new InvalidOperationException("The context is meant to be immutable, an attempt was made to set the context with the context already having data.");
        Current = this;
    }
}
```
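CallContext lives in System.Runtime.Remoting, which is only available on the full .NET Framework. If that is a concern, one alternative (a swapped-in technique, not the approach used above) is `AsyncLocal<T>`, available from .NET Framework 4.6 and on .NET Core. A minimal sketch that keeps the same set-once semantics; `TenantScope` is a hypothetical name:

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

// Alternative ambient tenant holder built on AsyncLocal<T> instead of CallContext.
public sealed class TenantScope
{
    private static readonly AsyncLocal<TenantScope> _current = new AsyncLocal<TenantScope>();

    public string TenantId { get; }

    public TenantScope(string tenantId)
    {
        TenantId = tenantId;
    }

    // Flows into tasks and async continuations started after the value is set.
    public static TenantScope Current => _current.Value;

    // Same rule as SetIntoContext: refuse to overwrite an existing value.
    public void SetIntoContext()
    {
        if (_current.Value != null)
            throw new InvalidOperationException("The tenant scope is immutable and has already been set.");
        _current.Value = this;
    }
}
```

A value set inside a task flows down into work that task starts, but it does not leak back out to the caller, which matches the "unique per request" behavior described above.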