It's time to harden security on your Spring Boot application!

One way to do that is by limiting failed logins by username or IP address. And if you haven't done that yet, I'll show you how here.

Some prereqs, though: you'll need a fundamental understanding of Spring Security. You may have followed my rather archaic way of handling security with JWT back in the day.

But rest assured the security system I use now has evolved quite a bit since then. Maybe I should update the article...

Anyhoo, you need at least a basic understanding of Spring Security to follow what's going on here. If you don't know much about it, there are plenty of resources online.

And with that disclaimer out of the way, let's get started.

The Models

It all starts with the models.

First, recall that in Spring Security your User object must implement UserDetails. But you're free to add custom properties to that object as well.

And that's what you'll do here.

Add a couple of properties that reflect the number of failed login attempts and a timestamp reflecting the most recent failed login.

public abstract class BaseUser implements UserDetails {

...
  
    protected Integer failedLoginAttempts;
    protected Long lastFailedLoginTime;

...
}

Note that I'm using a BaseUser object here because I'm handling much of the security logic in a separate dependency. But you can certainly add those properties to your concrete class.

Also, the last failed login time is stored as a Long representing the number of milliseconds since the Unix epoch (midnight UTC on January 1, 1970). That's the same value System.currentTimeMillis() returns.

Add the appropriate getters and setters as well.
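If you want to see those two fields with their accessors in one place, here's a minimal plain-Java sketch. FailedLoginFields is a hypothetical stand-in for BaseUser, and the UserDetails methods are left out. The lastFailedLoginInstant() helper is an optional convenience I'm assuming for readability. It isn't part of the original class.

```java
import java.time.Instant;

// Hypothetical stand-in for the tracking fields on BaseUser.
// The real class also implements UserDetails; that part is omitted here.
class FailedLoginFields {

    // null means "no failed logins since the last reset"
    protected Integer failedLoginAttempts;

    // Milliseconds since 1970-01-01T00:00:00Z, as returned by System.currentTimeMillis()
    protected Long lastFailedLoginTime;

    public Integer getFailedLoginAttempts() {
        return failedLoginAttempts;
    }

    public void setFailedLoginAttempts(Integer failedLoginAttempts) {
        this.failedLoginAttempts = failedLoginAttempts;
    }

    public Long getLastFailedLoginTime() {
        return lastFailedLoginTime;
    }

    public void setLastFailedLoginTime(Long lastFailedLoginTime) {
        this.lastFailedLoginTime = lastFailedLoginTime;
    }

    // Assumed helper (not in the original): a human-readable view of the stored timestamp
    public Instant lastFailedLoginInstant() {
        return lastFailedLoginTime == null ? null : Instant.ofEpochMilli(lastFailedLoginTime);
    }
}
```

For example, a stored value of 86400000L (one day's worth of milliseconds) comes back from the helper as 1970-01-02T00:00:00Z.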

Okay. That will work for tracking failed login attempts by user.

But someone could theoretically try logging in with different usernames from the same IP address in an attempt to brute-force their way in. You probably want to nip that in the bud as well.

So you need a new model.

Call it IpLog. It's going to track the IP addresses of people who've attempted to log in (whether they're successful or not).

And truth be told, it's good practice to log the IP addresses of people who visit your application. At least if you're serious about security.

So here's what that class looks like:

public abstract class BaseIpLog {

    protected String ipAddress;
    protected String username;
    protected Long lastLoginAttempt;
    protected Boolean successfulLogin;

//getters and setters

}

And once again, that's the base class. The implementation looks like this:

@Document(collection = "#{@environment.getProperty('mongo.iplog.collection')}")
public class IpLog extends BaseIpLog {

    @Id
    private String id;

//getters and setters

}

Why do I do it that way? To keep the implementation associated with the underlying database technology (in this case MongoDB) separate from the database-agnostic properties in the base class.

That's why.

And as you can tell from the @Document annotation, this object gets persisted into a MongoDB collection. 

That brings me to a point I should cover now...

Why Persistence?

You might have seen some other security solutions online that cache the failed login attempts rather than persisting them.

So why am I persisting them?

The short answer is that I don't want to give hackers enough leeway to take up precious server memory with their shenanigans.

The long answer is... well there is no long answer. That's pretty much it.

Plus, regarding the IP log, it's just a great idea to persist the IP addresses of folks who log in. You never know when you might need that info.

The Support System

You probably already have a repository and maybe even a service in place to handle persistence-related tasks for the User object. But not for the IpLog.

Fix that by creating the repo:

@Repository
public interface IpLogRepository extends MongoRepository<IpLog, String> {

}

Well that's easy enough.

But now create a service as well:

@Service
public class IpLogService implements IpTracker {
    
    private static final Logger LOG = LoggerFactory.getLogger(IpLogService.class);

    @Autowired
    private MongoTemplate mongoTemplate;

    @Autowired
    private IpLogRepository ipLogRepository;
    
    
    @Override
    public List<IpLog> fetchIpFailureRecord(String ipAddress, Long startingTime) {
        List<IpLog> list = new ArrayList<>();
        List<AggregationOperation> ops = new ArrayList<>();
        
        if (ipAddress != null) {
            AggregationOperation ipMatch = Aggregation.match(Criteria.where("ipAddress").is(ipAddress));
            ops.add(ipMatch);
            
            AggregationOperation dateThreshold = Aggregation.match(Criteria.where("lastLoginAttempt").gte(startingTime));
            ops.add(dateThreshold);

            AggregationOperation failMatch = Aggregation.match(Criteria.where("successfulLogin").is(false));
            ops.add(failMatch);
            
            Aggregation aggregation = Aggregation.newAggregation(ops);
            
            list = mongoTemplate.aggregate(aggregation, mongoTemplate.getCollectionName(IpLog.class), IpLog.class).getMappedResults();
        }        
        
        return list;
    }


    @Override
    public void successfulLogin(String username, String ipAddress) {
        IpLog ipLog = new IpLog();
        
        ipLog.setIpAddress(ipAddress);
        ipLog.setLastLoginAttempt(System.currentTimeMillis());
        ipLog.setSuccessfulLogin(true);
        ipLog.setUsername(username);
        
        ipLogRepository.save(ipLog);
    }
    
    
    @Override
    public void unsuccessfulLogin(String username, String ipAddress) {
        IpLog ipLog = new IpLog();
        
        ipLog.setIpAddress(ipAddress);
        ipLog.setLastLoginAttempt(System.currentTimeMillis());
        ipLog.setSuccessfulLogin(false);
        ipLog.setUsername(username);
        
        ipLogRepository.save(ipLog);
    }    
}

Okay. That's the first "big" class of this guide.

For starters, you don't need to implement IpTracker if you don't want to. I do it that way in my solution so that the implementation is ultimately available within the security dependency, which is where the interface is defined.

For now, just focus on what those three methods are doing.

The first method fetches the IP failure log. That's all the times the given IP address failed a login attempt within the specified timeframe.

Note that the starting point of the timeframe is a Long object that represents the number of milliseconds since 1970 began.

That method also accepts the IP address as a parameter.

The fetchIpFailureRecord() method uses MongoDB aggregation to retrieve the documents that match the criteria. I prefer to use services for this kind of thing rather than repositories with those awkwardly named methods or bizarre query annotations.

The other two methods persist successful and unsuccessful logins, respectively.
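To make the three $match stages in fetchIpFailureRecord() concrete, here's a plain-Java sketch of the same filter applied to an in-memory list. It's only an illustration. IpLogEntry and failuresSince() are hypothetical names, and in the real service MongoDB does this filtering server-side. Because consecutive $match stages are effectively ANDed together, a document has to pass all three conditions to be counted.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical in-memory stand-in for an IpLog document
class IpLogEntry {
    final String ipAddress;
    final String username;
    final long lastLoginAttempt;   // epoch milliseconds
    final boolean successfulLogin;

    IpLogEntry(String ipAddress, String username, long lastLoginAttempt, boolean successfulLogin) {
        this.ipAddress = ipAddress;
        this.username = username;
        this.lastLoginAttempt = lastLoginAttempt;
        this.successfulLogin = successfulLogin;
    }
}

class IpFailureFilter {

    // Mirrors the three $match stages: same IP, at or after startingTime, and a failed login
    static List<IpLogEntry> failuresSince(List<IpLogEntry> logs, String ipAddress, long startingTime) {
        List<IpLogEntry> result = new ArrayList<>();
        if (ipAddress == null) {
            return result; // the service also returns an empty list for a null IP
        }
        for (IpLogEntry entry : logs) {
            if (ipAddress.equals(entry.ipAddress)
                    && entry.lastLoginAttempt >= startingTime
                    && !entry.successfulLogin) {
                result.add(entry);
            }
        }
        return result;
    }
}
```

Note that successful logins from the same IP don't cancel out earlier failures here. They're simply excluded from the count.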

At Your Service

Next, update your implementation of UserDetailsService. This is where you'll add a method to record failed login attempts. 

Here's what that method looks like in my class:

    public void updateFailedLoginAttempts(String username) {
        try {
            UserDetails userDetails = loadUserByUsername(username);
            User user = (User)userDetails;
            
            Integer failedLoginAttempts = user.getFailedLoginAttempts();
            if (failedLoginAttempts == null) {
                failedLoginAttempts = 1;
            } else {
                failedLoginAttempts++;
            }
            
            user.setFailedLoginAttempts(failedLoginAttempts);
            user.setLastFailedLoginTime(System.currentTimeMillis());
            
            ((UserRepository)userDetailsRepository).save(user);
        } catch (UsernameNotFoundException e) {
            LOG.error("Problem attempting to update failed login attempts!", e);
        }
    }

Pretty straightforward stuff. It loads the User object by username. Then it increments the failed login count and sets the last failed login time to right now. Finally, it saves the user.

Forgive the explicit cast you see with the save() method. You probably won't need to do anything like that, but I need it because, once again, the interface is in a dependency that doesn't know anything about the implementation-specific save() method.

While you're in that UserDetailsService implementation, add a couple more methods:

    public void successfulLogin(String username) {
        resetFailedLoginAttempts(username);
    }
    
    
    private void resetFailedLoginAttempts(String username) {
        UserDetails userDetails = loadUserByUsername(username);
        User user = (User)userDetails;
        
        Integer failedLoginAttempts = user.getFailedLoginAttempts();
        if (failedLoginAttempts != null) {
            user.setFailedLoginAttempts(null);
            ((UserRepository)userDetailsRepository).save(user);
        }
    }

The public method there is what external objects will use as a callback to indicate that the user logged in successfully.

The private method does the work of resetting the user's failed login attempts to null. This application only counts consecutive failed login attempts, so when the user successfully gets in, the counter gets reset.
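Here's a minimal in-memory sketch of that increment-and-reset behavior. It assumes a hypothetical FailedLoginCounter class, with two Maps standing in for the user repository. It isn't the article's service. It just isolates the counting rule so you can see "consecutive failures" in action.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical in-memory stand-in for the failed-login fields on each user
class FailedLoginCounter {

    private final Map<String, Integer> failedAttempts = new HashMap<>();
    private final Map<String, Long> lastFailedTime = new HashMap<>();

    // Same rule as updateFailedLoginAttempts(): absent (null) becomes 1, otherwise increment
    void recordFailure(String username, long nowMillis) {
        failedAttempts.merge(username, 1, Integer::sum);
        lastFailedTime.put(username, nowMillis);
    }

    // Same rule as resetFailedLoginAttempts(): a success clears the counter.
    // Like the original, the last failure time is left in place.
    void recordSuccess(String username) {
        failedAttempts.remove(username);
    }

    // Returns null when there are no consecutive failures, like the User field
    Integer failedAttempts(String username) {
        return failedAttempts.get(username);
    }

    Long lastFailedTime(String username) {
        return lastFailedTime.get(username);
    }
}
```

Two failures followed by a success leave the count at null, so the next failure starts again at 1.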

An Exception to Every Rule

Next, you're going to need some failure-specific exceptions:

  • An exception for too many failed logins from the same IP
  • An exception for too many failed logins from a single user

Both of those exceptions should extend AuthenticationException because that is, in fact, the point of AuthenticationException. It's an abstract class you extend with details specific to a particular type of authentication failure.

I created one class called TooManyFailedIpLoginsAuthenticationException and another called TooManyFailedLoginsAuthenticationException.

The code is ridiculously simple:

public class TooManyFailedLoginsAuthenticationException extends AuthenticationException {
    
    private static final long serialVersionUID = 5368673516685167890L;

    public TooManyFailedLoginsAuthenticationException(String s) {
        super(s);
    }
}

public class TooManyFailedIpLoginsAuthenticationException extends AuthenticationException {
    
    private static final long serialVersionUID = -6313473860143052407L;

    public TooManyFailedIpLoginsAuthenticationException(String s) {
        super(s);
    }
}

The only difference between those two classes is the name.

But the name is important because it defines the type of authentication failure. I'll show you how I use it in a moment.

A Utility You Can Utilize

Next, create a new utility class that keeps your code neat and separates concerns. Call it LoginAttemptsUtil.

public class LoginAttemptsUtil {

    private static final Logger LOG = LoggerFactory.getLogger(LoginAttemptsUtil.class);
    
    private static final int MAX_FAILED_LOGINS = 4;
    private static final long FAILED_LOGIN_TIMEOUT_PERIOD = DateConversionUtil.NUMBER_OF_MILLISECONDS_IN_DAY;

    
    private JwtUserDetailsService jwtUserDetailsService;
    private IpTracker ipTracker;
    
    
    public LoginAttemptsUtil(JwtUserDetailsService jwtUserDetailsService, IpTracker ipTracker) {
        this.jwtUserDetailsService = jwtUserDetailsService;
        this.ipTracker = ipTracker;
    }
    
    
    public void checkMaxLoginAttempts(JwtRequest jwtRequest) {
        LOG.debug("Checking for too many failed logins");
        
        if (jwtRequest != null && jwtRequest.getUsername() != null) {
            BaseUser user = (BaseUser)jwtUserDetailsService.loadUserByUsername(jwtRequest.getUsername());        
            checkForFailedLogins(user);
        } else {
            throw new UserServiceAuthenticationException("Can't parse login request!");
        }
    }
    
    
    private void checkForFailedLogins(BaseUser user) {
        if (user.getFailedLoginAttempts() != null) {
            if (user.getFailedLoginAttempts() >= MAX_FAILED_LOGINS) {
                checkDateThreshold(user);
            }
        }
    }
    
    
    private void checkDateThreshold(BaseUser user) {
        if (user.getLastFailedLoginTime() != null) {
            Long now = System.currentTimeMillis();
            Long difference = now - user.getLastFailedLoginTime();
            
            if (difference < FAILED_LOGIN_TIMEOUT_PERIOD) {
                throw new TooManyFailedLoginsAuthenticationException("Too many failed logins!");
            }
        }
    }
    
    
    /**
     * Check to make sure this user hasn't failed authentication too many times
     * from the same IP address.
     */
    public void checkIpValidity(HttpServletRequest request) {
        String ipAddress = request.getRemoteAddr();
        
        //timeframe in the past 24 hours
        Long timeframe = System.currentTimeMillis() - FAILED_LOGIN_TIMEOUT_PERIOD;
        
        List<? extends BaseIpLog> list = ipTracker.fetchIpFailureRecord(ipAddress, timeframe);
        if (list != null && list.size() >= MAX_FAILED_LOGINS) {
            throw new TooManyFailedIpLoginsAuthenticationException("Too many failed logins from this IP address!");
        }
    }
}

So I'm hardcoding the maximum number of failed login attempts (four) and the time threshold (one day). You can put those in a .properties or .yml file if you prefer, but I don't think they'll change much.
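If you do want to pull the threshold and timeout out of the constants, one option is to pass them in along with a java.time.Clock. That way the lockout rule can be tested without waiting a day. Here's a minimal sketch of the per-user check under that assumption. LockoutPolicy and isLockedOut() are hypothetical names. Also, the real LoginAttemptsUtil throws TooManyFailedLoginsAuthenticationException rather than returning a boolean.

```java
import java.time.Clock;
import java.time.Duration;

// Hypothetical, configurable version of the per-user rule in LoginAttemptsUtil
class LockoutPolicy {

    private final int maxFailedLogins;
    private final Duration timeoutPeriod;
    private final Clock clock;

    LockoutPolicy(int maxFailedLogins, Duration timeoutPeriod, Clock clock) {
        this.maxFailedLogins = maxFailedLogins;
        this.timeoutPeriod = timeoutPeriod;
        this.clock = clock;
    }

    // Same logic as checkForFailedLogins() plus checkDateThreshold(), returning a boolean
    boolean isLockedOut(Integer failedLoginAttempts, Long lastFailedLoginTime) {
        if (failedLoginAttempts == null || failedLoginAttempts < maxFailedLogins) {
            return false; // not enough consecutive failures yet
        }
        if (lastFailedLoginTime == null) {
            return false; // no timestamp, so nothing to measure against
        }
        long elapsed = clock.millis() - lastFailedLoginTime;
        return elapsed < timeoutPeriod.toMillis(); // locked until the timeout has passed
    }
}
```

In production you'd construct it with Clock.systemUTC(). In a test, Clock.fixed(...) lets you check both sides of the one-day boundary deterministically.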

Next, note that the class depends on two services: JwtUserDetailsService and IpTracker.

In your solution, you could probably make LoginAttemptsUtil a Spring bean and have both services injected. But because I'm using a library, and one of those types is abstract while the other is an interface, I pass them in the old-fashioned way: via the constructor.

The first method, checkMaxLoginAttempts(), checks for max login attempts by user.

The next two methods support that first method.

The last method, checkIpValidity(), checks for max login attempts by IP address.

Note that the methods in the class above don't actually return anything. Instead, they just throw exceptions when something is wrong.

Okay. Now that you've got your support classes in place, it's time to use them.

Fiddling With the Filter

To handle user login by name and password, I use a class that extends Spring Security's UsernamePasswordAuthenticationFilter. It's called CredentialsAuthenticationFilter.

I won't get into everything associated with that class here (remember the prereqs), but here's what the attemptAuthentication() method looks like:

    @Override
    public Authentication attemptAuthentication(HttpServletRequest req, HttpServletResponse res) throws AuthenticationException {
        JwtRequest jwtRequest = null;
        ObjectMapper mapper = new ObjectMapper();
        LoginAttemptsUtil loginAttemptsUtil = new LoginAttemptsUtil(jwtUserDetailsService, ipTracker);
        
        try {
            //make sure the user hasn't failed login too many times from this IP address
            loginAttemptsUtil.checkIpValidity(req);            

            //construct the JwtRequest object from the input stream
            jwtRequest = mapper.readValue(req.getInputStream(), JwtRequest.class);

            //now check to make sure this user hasn't had too many failed login attempts
            loginAttemptsUtil.checkMaxLoginAttempts(jwtRequest);
            
            //handle login
            return handleLogin(jwtRequest);
        } catch (BadCredentialsException e) {
            LOG.error("Bad credentials!", e);
            
            //gotta log to both the user service and ip tracker
            //because the user service tracks failed login attempts per user
            //while the ip tracker tracks failed login attempts per ip
            jwtUserDetailsService.updateFailedLoginAttempts(jwtRequest.getUsername());
            ipTracker.unsuccessfulLogin(jwtRequest.getUsername(), req.getRemoteAddr());
            
            throw new InvalidCredentialsAuthenticationException(e.getMessage());
        } catch (JsonMappingException e) {
            LOG.error("Problem logging in user with credentials!", e);
            throw new UserServiceAuthenticationException(e.getMessage());
        } catch (IOException e) {
            LOG.error("Problem logging in user with credentials!", e);
            throw new UserServiceAuthenticationException(e.getMessage());
        }
    }

Focus on what's happening inside the try block.

First, the code checks whether the user's IP address has too many recent failed logins. If so, the checkIpValidity() method throws TooManyFailedIpLoginsAuthenticationException and the login process stops right there. None of the catch blocks handle that exception, so it propagates out of attemptAuthentication() as well.

If no exception is thrown, then the person hasn't failed authentication too many times from that IP address.

Next, the code creates a JwtRequest object from the HTTP request's input stream. That's a simple class that holds the username and password.

Well now that the code knows the username, it can use that info to check to see if this user has failed login too many times.

And that's what the checkMaxLoginAttempts() method will do.

Once again: that method doesn't return anything. It just throws an exception if this user has too many recent failed logins.

Next, pay attention to the first catch block. That catches BadCredentialsException, which is the exception that (unsurprisingly) gets thrown when the user enters bad credentials.

Once that happens, the code needs to log that failed login attempt. And it does that with the aid of those two services I mentioned earlier:

The first service, jwtUserDetailsService, logs the failed login attempt from a user perspective via updateFailedLoginAttempts().

The second service, ipTracker, logs the failed login attempt from an IP perspective via unsuccessfulLogin().

That first catch block also throws a custom exception: InvalidCredentialsAuthenticationException. But you could just as easily rethrow BadCredentialsException.
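To see the ordering of those steps without the Spring plumbing, here's a stripped-down plain-Java sketch of the same sequence. Everything in it is hypothetical:

- LoginFlowSketch stands in for the filter.
- CredentialChecker stands in for handleLogin().
- The two predicates stand in for checkIpValidity() and checkMaxLoginAttempts().
- The RuntimeException subclasses stand in for the AuthenticationException subclasses.

The point is the order. The IP check runs before the request body is even parsed. The per-user check runs before the password is verified. And a bad password is recorded against both the user and the IP.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Predicate;

// Hypothetical stand-ins for the article's AuthenticationException subclasses
class TooManyIpFailures extends RuntimeException {}
class TooManyUserFailures extends RuntimeException {}
class BadCredentials extends RuntimeException {}
class InvalidCredentials extends RuntimeException {}

// Hypothetical stand-in for handleLogin(): returns a token or throws BadCredentials
interface CredentialChecker {
    String authenticate(String username, String password);
}

class LoginFlowSketch {

    private final Predicate<String> isIpBlocked;    // stand-in for checkIpValidity()
    private final Predicate<String> isUserBlocked;  // stand-in for checkMaxLoginAttempts()
    private final CredentialChecker checker;

    // Stand-ins for updateFailedLoginAttempts() and ipTracker.unsuccessfulLogin()
    final List<String> failedUsers = new ArrayList<>();
    final List<String> failedIps = new ArrayList<>();

    LoginFlowSketch(Predicate<String> isIpBlocked, Predicate<String> isUserBlocked,
                    CredentialChecker checker) {
        this.isIpBlocked = isIpBlocked;
        this.isUserBlocked = isUserBlocked;
        this.checker = checker;
    }

    String attempt(String ip, String username, String password) {
        // 1. Reject a noisy IP before anything else happens
        if (isIpBlocked.test(ip)) {
            throw new TooManyIpFailures();
        }
        // 2. The username is known now, so check the per-user counter
        if (isUserBlocked.test(username)) {
            throw new TooManyUserFailures();
        }
        try {
            // 3. Only now is the password actually verified
            return checker.authenticate(username, password);
        } catch (BadCredentials e) {
            // 4. Record the failure against both the user and the IP, then rethrow
            failedUsers.add(username);
            failedIps.add(ip);
            throw new InvalidCredentials();
        }
    }
}
```

Because the lockout checks throw before step 3, a blocked user or IP never reaches the password check, so those attempts aren't recorded as new failures.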

Prepping for Failure

If you're familiar with Spring Security (and you should be, see the prereqs again), then you already know about the AuthenticationFailureHandler interface. It's your job to implement that interface with a class that handles authentication failures.

I handle that implementation with a lambda expression in my WebSecurityConfigurerAdapter extension.

Here's what the method looks like in my class:

    protected AuthenticationFailureHandler authenticationFailureHandler() {
        return (request, response, ex) -> {
            if (ex instanceof InvalidCredentialsAuthenticationException) {
                ResponseUtil.invalidCredentials(response);
            } else if (ex instanceof TooManyFailedIpLoginsAuthenticationException) {
                ResponseUtil.tooManyFailedIpLogins(response);
            } else if (ex instanceof TooManyFailedLoginsAuthenticationException) {
                ResponseUtil.tooManyFailedLogins(response);
            } else {
                response.setStatus(HttpStatus.UNAUTHORIZED.value());
                ResponseWriterUtil.writeResponse(response, ex.getMessage(), ResponseStatusCode.UNAUTHORIZED);                                  
            }
        };
    }

The executive summary of what's going on there is that the implementation examines the exception that brought it to that point. Then, it creates a response based on the exception type.

You can check out ResponseUtil and ResponseWriterUtil to see more about how the response gets written. The bottom line here is that it provides a user-friendly message like "This IP address failed login in too many times today."

That message can get put in a JSON response where it will show up in a tool like Postman. It can also get put in the UI if a client-side app is making the call.
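Here's a minimal plain-Java sketch of that dispatch. It maps each exception type to a status code and a message instead of writing to the response. The class names and messages are hypothetical. The status codes are my assumption, since ResponseUtil isn't shown here. I've used 401 for bad credentials and 429 (Too Many Requests) for the two lockouts, but you might prefer 401 or 403 across the board. The RuntimeException subclasses stand in for the article's AuthenticationException subclasses.

```java
// Hypothetical stand-ins for the article's AuthenticationException subclasses
class InvalidCredentialsException extends RuntimeException {}
class TooManyFailedIpLoginsException extends RuntimeException {}
class TooManyFailedLoginsException extends RuntimeException {}

// What the handler would write: an HTTP status plus a user-friendly message
class FailureResponse {
    final int status;
    final String message;

    FailureResponse(int status, String message) {
        this.status = status;
        this.message = message;
    }
}

class FailureResponses {

    // Same instanceof chain as authenticationFailureHandler(), but returning a value
    // instead of writing to the HttpServletResponse. Status codes are assumptions.
    static FailureResponse forException(RuntimeException ex) {
        if (ex instanceof InvalidCredentialsException) {
            return new FailureResponse(401, "Invalid credentials.");
        } else if (ex instanceof TooManyFailedIpLoginsException) {
            return new FailureResponse(429, "This IP address has failed to log in too many times today.");
        } else if (ex instanceof TooManyFailedLoginsException) {
            return new FailureResponse(429, "Too many failed logins for this account. Try again later.");
        }
        // Fallback, like the original's else branch: 401 with the exception's own message
        return new FailureResponse(401, ex.getMessage());
    }
}
```

Because these are instanceof checks, order matters if you ever make one exception extend another. Put the more specific type first.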

And don't forget to set that authentication failure handler in your CredentialsAuthenticationFilter object like so:

    protected CredentialsAuthenticationFilter credentialsAuthenticationFilter() throws Exception {
        CredentialsAuthenticationFilter filter = new CredentialsAuthenticationFilter(authenticationManager());
        filter.setAuthenticationFailureHandler(authenticationFailureHandler());
        //other stuff
        return filter;
    }

And you'll reference that filter in your security configuration:

    @Override
    protected void configure(HttpSecurity httpSecurity) throws Exception {              
        httpSecurity
            .cors().and()
            .csrf().disable()
            .addFilter(bearerTokenAuthenticationFilter())
            .addFilter(credentialsAuthenticationFilter())
            .authorizeRequests()
            .anyRequest().hasAnyAuthority(allowedAuthorities).and()
            .sessionManagement().sessionCreationPolicy(SessionCreationPolicy.STATELESS);
    }

You can see the whole implementation in my CredentialsAndJwtSecurityConfig class.

But What If It's a Success?

Head on back to the CredentialsAuthenticationFilter class. Here's what I've done with the overridden successfulAuthentication() method:

    @Override
    protected void successfulAuthentication(HttpServletRequest req, HttpServletResponse res, 
            FilterChain chain, Authentication auth) throws IOException {
        
        final BaseUser user = (BaseUser)auth.getPrincipal();
        final String token = jwtUtil.generateToken(user);
        final Long expirationDate = jwtUtil.getExpirationDateFromToken(token).getTime();

        //log a successful authentication to iplog collection
        ipTracker.successfulLogin(user.getUsername(), req.getRemoteAddr());
        
        //log a successful authentication to user collection
        jwtUserDetailsService.successfulLogin(user.getUsername());
        
        JwtResponse jwtResponse = new JwtResponse(token, user, expirationDate);
        
        String body = new ObjectMapper().writeValueAsString(jwtResponse);
        LOG.debug("Body response is " + body);
        
        res.getWriter().write(body);
        res.getWriter().flush();
    }

The point of that method is to construct the JSON Web Token (JWT) and send it back to the user.

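But before it does that, it invokes the successfulLogin() methods on both of the services you've already looked at. On the user side, that resets the consecutive failed-login counter. On the IP side, nothing is actually reset. checkIpValidity() counts every failed-login document from that address in the past 24 hours, so earlier failures still count until they age out of that window.

What the successfulLogin() method on ipTracker does do is persist a MongoDB document that reflects a successful login from the IP address of the current user.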

What About Filters and Interceptors?

You might be wondering why I don't handle this solution with either a filter or an interceptor. Let me take a few moments to answer that.

First, you could do it with an implementation of HandlerInterceptor. I decided to go in a different direction because I had problems getting Spring Boot to fire the interceptor.

And I'm not the only one. Others have experienced similar issues and they've fixed them by just restructuring how they declare the interceptor.

That solution didn't work for me so I decided not to use it. 

Regarding filters: I am using a filter. It's the security-level filter called CredentialsAuthenticationFilter.

But if you're thinking about a custom OncePerRequestFilter, it would work well here, except that it doesn't play nicely with Spring's implementation of CORS (Cross-Origin Resource Sharing). I found that my Angular app was getting CORS failures when the filter returned an error response.

I'm not a huge fan of Spring's CORS implementation as it stands. And I needed to dump it anyway because it's difficult to specify origins for service-to-service interaction in a Kubernetes environment.

So I pulled the plug on the filter solution and went with what I showed you here.

Wrapping It Up

That should pretty much do it.

Undoubtedly, you'll have to adapt what you've seen here to suit your own security solution. But at least you're headed in the right direction.

Have fun!
